From 594aad3063f31fd230d2f4dec8c6e15a9e128f8c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Tue, 26 Nov 2024 11:08:53 +0100 Subject: [PATCH] refactor: Use heartbeat date ranges to track progress (#26094) --- mypy-baseline.txt | 5 - .../temporal/batch_exports/batch_exports.py | 107 +++++- .../batch_exports/bigquery_batch_export.py | 38 +- posthog/temporal/batch_exports/heartbeat.py | 215 ++++++++++++ .../batch_exports/redshift_batch_export.py | 62 +++- .../temporal/batch_exports/s3_batch_export.py | 112 +++--- .../batch_exports/snowflake_batch_export.py | 56 +-- .../temporal/batch_exports/temporary_file.py | 49 ++- posthog/temporal/common/heartbeat.py | 122 ++++++- posthog/temporal/common/utils.py | 149 -------- .../tests/batch_exports/test_batch_exports.py | 149 +++++++- .../test_bigquery_batch_export_workflow.py | 326 +++++++++++++++--- .../tests/batch_exports/test_heartbeat.py | 104 ++++++ .../test_redshift_batch_export_workflow.py | 318 +++++++++++++++-- .../test_s3_batch_export_workflow.py | 30 +- .../test_snowflake_batch_export_workflow.py | 22 +- .../batch_exports/test_temporary_file.py | 64 ++-- posthog/temporal/tests/batch_exports/utils.py | 1 - posthog/temporal/tests/utils/events.py | 82 +++-- 19 files changed, 1575 insertions(+), 436 deletions(-) create mode 100644 posthog/temporal/batch_exports/heartbeat.py delete mode 100644 posthog/temporal/common/utils.py create mode 100644 posthog/temporal/tests/batch_exports/test_heartbeat.py diff --git a/mypy-baseline.txt b/mypy-baseline.txt index 5eb6a53922ee2..c7017447c0f92 100644 --- a/mypy-baseline.txt +++ b/mypy-baseline.txt @@ -1,6 +1,3 @@ -posthog/temporal/common/utils.py:0: error: Argument 1 to "abstractclassmethod" has incompatible type "Callable[[HeartbeatDetails, Any], Any]"; expected "Callable[[type[Never], Any], Any]" [arg-type] -posthog/temporal/common/utils.py:0: note: This is likely because "from_activity" has named arguments: "cls". Consider marking them positional-only -posthog/temporal/common/utils.py:0: error: Argument 2 to "__get__" of "classmethod" has incompatible type "type[HeartbeatType]"; expected "type[Never]" [arg-type] posthog/tasks/exports/ordered_csv_renderer.py:0: error: No return value expected [return-value] posthog/warehouse/models/ssh_tunnel.py:0: error: Incompatible types in assignment (expression has type "NoEncryption", variable has type "BestAvailableEncryption") [assignment] posthog/temporal/data_imports/pipelines/sql_database_v2/schema_types.py:0: error: Statement is unreachable [unreachable] @@ -829,8 +826,6 @@ posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_execute_async_calls" (hint: "_execute_async_calls: list[] = ...") [var-annotated] posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: Need type annotation for "_cursors" (hint: "_cursors: list[] = ...") [var-annotated] posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py:0: error: List item 0 has incompatible type "tuple[str, str, int, int, int, int, str, int]"; expected "tuple[str, str, int, int, str, str, str, str]" [list-item] -posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py:0: error: "tuple[Any, ...]" has no attribute "last_uploaded_part_timestamp" [attr-defined] -posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py:0: error: "tuple[Any, ...]" has no attribute "upload_state" [attr-defined] posthog/migrations/0237_remove_timezone_from_teams.py:0: error: Argument 2 to "RunPython" has incompatible type "Callable[[Migration, Any], None]"; expected "_CodeCallable | None" [arg-type] posthog/migrations/0228_fix_tile_layouts.py:0: error: Argument 2 to "RunPython" has incompatible type "Callable[[Migration, Any], None]"; expected "_CodeCallable | None" [arg-type] posthog/api/plugin_log_entry.py:0: error: Name "timezone.datetime" is not defined [name-defined] diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index f980c91dc25ed..fd9a718667262 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -3,6 +3,7 @@ import collections.abc import dataclasses import datetime as dt +import operator import typing import uuid from string import Template @@ -361,8 +362,8 @@ def start_produce_batch_export_record_batches( model_name: str, is_backfill: bool, team_id: int, - interval_start: str | None, - interval_end: str, + full_range: tuple[dt.datetime | None, dt.datetime], + done_ranges: list[tuple[dt.datetime, dt.datetime]], fields: list[BatchExportField] | None = None, destination_default_fields: list[BatchExportField] | None = None, **parameters, @@ -386,7 +387,7 @@ def start_produce_batch_export_record_batches( fields = destination_default_fields if model_name == "persons": - if is_backfill and interval_start is None: + if is_backfill and full_range[0] is None: view = SELECT_FROM_PERSONS_VIEW_BACKFILL else: view = SELECT_FROM_PERSONS_VIEW @@ -420,26 +421,112 @@ def start_produce_batch_export_record_batches( view = query_template.substitute(fields=query_fields) - if interval_start is not None: - parameters["interval_start"] = dt.datetime.fromisoformat(interval_start).strftime("%Y-%m-%d %H:%M:%S") - - parameters["interval_end"] = dt.datetime.fromisoformat(interval_end).strftime("%Y-%m-%d %H:%M:%S") parameters["team_id"] = team_id extra_query_parameters = parameters.pop("extra_query_parameters", {}) or {} parameters = {**parameters, **extra_query_parameters} queue = RecordBatchQueue(max_size_bytes=settings.BATCH_EXPORT_BUFFER_QUEUE_MAX_SIZE_BYTES) - query_id = uuid.uuid4() produce_task = asyncio.create_task( - client.aproduce_query_as_arrow_record_batches( - view, queue=queue, query_parameters=parameters, query_id=str(query_id) + produce_batch_export_record_batches_from_range( + client=client, + query=view, + full_range=full_range, + done_ranges=done_ranges, + queue=queue, + query_parameters=parameters, ) ) return queue, produce_task +async def produce_batch_export_record_batches_from_range( + client: ClickHouseClient, + query: str, + full_range: tuple[dt.datetime | None, dt.datetime], + done_ranges: collections.abc.Sequence[tuple[dt.datetime, dt.datetime]], + queue: RecordBatchQueue, + query_parameters: dict[str, typing.Any], +): + """Produce all record batches into `queue` required to complete `full_range`. + + This function will skip over any already completed `done_ranges`. + """ + for interval_start, interval_end in generate_query_ranges(full_range, done_ranges): + if interval_start is not None: + query_parameters["interval_start"] = interval_start.strftime("%Y-%m-%d %H:%M:%S.%f") + query_parameters["interval_end"] = interval_end.strftime("%Y-%m-%d %H:%M:%S.%f") + query_id = uuid.uuid4() + + await client.aproduce_query_as_arrow_record_batches( + query, queue=queue, query_parameters=query_parameters, query_id=str(query_id) + ) + + +def generate_query_ranges( + remaining_range: tuple[dt.datetime | None, dt.datetime], + done_ranges: collections.abc.Sequence[tuple[dt.datetime, dt.datetime]], +) -> typing.Iterator[tuple[dt.datetime | None, dt.datetime]]: + """Recursively yield ranges of dates that need to be queried. + + There are essentially 3 scenarios we are expecting: + 1. The batch export just started, so we expect `done_ranges` to be an empty + list, and thus should return the `remaining_range`. + 2. The batch export crashed mid-execution, so we have some `done_ranges` that + do not completely add up to the full range. In this case we need to yield + ranges in between all the done ones. + 3. The batch export crashed right after we finish, so we have a full list of + `done_ranges` adding up to the `remaining_range`. In this case we should not + yield anything. + + Case 1 is fairly trivial and we can simply return `remaining_range` if we get + an empty `done_ranges`. + + Case 2 is more complicated and we can expect that the ranges produced by this + function will lead to duplicate events selected, as our batch export query is + inclusive in the lower bound. Since multiple rows may have the same + `inserted_at` we cannot simply skip an `inserted_at` value, as there may be a + row that hasn't been exported as it with the same `inserted_at` as a row that + has been exported. So this function will return ranges with `inserted_at` + values that were already exported for at least one event. Ideally, this is + *only* one event, but we can never be certain. + """ + if len(done_ranges) == 0: + yield remaining_range + return + + epoch = dt.datetime.fromtimestamp(0, tz=dt.UTC) + list_done_ranges: list[tuple[dt.datetime, dt.datetime]] = list(done_ranges) + + list_done_ranges.sort(key=operator.itemgetter(0)) + + while True: + try: + next_range: tuple[dt.datetime | None, dt.datetime] = list_done_ranges.pop(0) + except IndexError: + if remaining_range[0] != remaining_range[1]: + # If they were equal it would mean we have finished. + yield remaining_range + + return + else: + candidate_end_at = next_range[0] if next_range[0] is not None else epoch + + candidate_start_at = remaining_range[0] + remaining_range = (next_range[1], remaining_range[1]) + + if candidate_start_at is not None and candidate_start_at >= candidate_end_at: + # We have landed within a done range. + continue + + if candidate_start_at is None and candidate_end_at == epoch: + # We have landed within the first done range of a backfill. + continue + + yield (candidate_start_at, candidate_end_at) + + async def raise_on_produce_task_failure(produce_task: asyncio.Task) -> None: """Raise `RecordBatchProducerError` if a produce task failed. diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index 6b5fe3beb42fe..e4eea3625a7fc 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -49,13 +49,15 @@ cast_record_batch_json_columns, set_status_to_running_task, ) +from posthog.temporal.batch_exports.heartbeat import ( + BatchExportRangeHeartbeatDetails, + DateRange, + should_resume_from_activity_heartbeat, +) + from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import configure_temporal_worker_logger -from posthog.temporal.common.utils import ( - BatchExportHeartbeatDetails, - should_resume_from_activity_heartbeat, -) logger = structlog.get_logger() @@ -113,7 +115,7 @@ def get_bigquery_fields_from_record_schema( @dataclasses.dataclass -class BigQueryHeartbeatDetails(BatchExportHeartbeatDetails): +class BigQueryHeartbeatDetails(BatchExportRangeHeartbeatDetails): """The BigQuery batch export details included in every heartbeat.""" pass @@ -366,12 +368,11 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - should_resume, details = await should_resume_from_activity_heartbeat(activity, BigQueryHeartbeatDetails, logger) + _, details = await should_resume_from_activity_heartbeat(activity, BigQueryHeartbeatDetails) + if details is None: + details = BigQueryHeartbeatDetails() - if should_resume is True and details is not None: - data_interval_start: str | None = details.last_inserted_at.isoformat() - else: - data_interval_start = inputs.data_interval_start + done_ranges: list[DateRange] = details.done_ranges model: BatchExportModel | BatchExportSchema | None = None if inputs.batch_export_schema is None and "batch_export_model" in { @@ -392,13 +393,18 @@ async def insert_into_bigquery_activity(inputs: BigQueryInsertInputs) -> Records extra_query_parameters = model["values"] if model is not None else {} fields = model["fields"] if model is not None else None + data_interval_start = ( + dt.datetime.fromisoformat(inputs.data_interval_start) if inputs.data_interval_start else None + ) + data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end) + full_range = (data_interval_start, data_interval_end) queue, produce_task = start_produce_batch_export_record_batches( client=client, model_name=model_name, is_backfill=inputs.is_backfill, team_id=inputs.team_id, - interval_start=data_interval_start, - interval_end=inputs.data_interval_end, + full_range=full_range, + done_ranges=done_ranges, exclude_events=inputs.exclude_events, include_events=inputs.include_events, fields=fields, @@ -490,7 +496,7 @@ async def flush_to_bigquery( records_since_last_flush: int, bytes_since_last_flush: int, flush_counter: int, - last_inserted_at, + last_date_range, last: bool, error: Exception | None, ): @@ -508,7 +514,8 @@ async def flush_to_bigquery( rows_exported.add(records_since_last_flush) bytes_exported.add(bytes_since_last_flush) - heartbeater.details = (str(last_inserted_at),) + details.track_done_range(last_date_range, data_interval_start) + heartbeater.set_from_heartbeat_details(details) flush_tasks = [] while not queue.empty() or not produce_task.done(): @@ -535,6 +542,9 @@ async def flush_to_bigquery( await raise_on_produce_task_failure(produce_task) await logger.adebug("Successfully consumed all record batches") + details.complete_done_ranges(inputs.data_interval_end) + heartbeater.set_from_heartbeat_details(details) + records_total = functools.reduce(operator.add, (task.result() for task in flush_tasks)) if requires_merge: diff --git a/posthog/temporal/batch_exports/heartbeat.py b/posthog/temporal/batch_exports/heartbeat.py new file mode 100644 index 0000000000000..fdd21d0613eee --- /dev/null +++ b/posthog/temporal/batch_exports/heartbeat.py @@ -0,0 +1,215 @@ +import typing +import datetime as dt +import collections.abc +import dataclasses + +import structlog + +from posthog.temporal.common.heartbeat import ( + HeartbeatDetails, + HeartbeatParseError, + EmptyHeartbeatError, + NotEnoughHeartbeatValuesError, +) + +DateRange = tuple[dt.datetime, dt.datetime] + +logger = structlog.get_logger() + + +@dataclasses.dataclass +class BatchExportRangeHeartbeatDetails(HeartbeatDetails): + """Details included in every batch export heartbeat. + + Attributes: + done_ranges: Date ranges that have been successfully exported. + _remaining: Anything else in the activity details. + """ + + done_ranges: list[DateRange] = dataclasses.field(default_factory=list) + _remaining: collections.abc.Sequence[typing.Any] = dataclasses.field(default_factory=tuple) + + @classmethod + def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]: + """Deserialize this from Temporal activity details. + + We expect done ranges to be available in the first index of remaining + values. Moreover, we expect datetime values to be ISO-formatted strings. + """ + done_ranges: list[DateRange] = [] + remaining = super().deserialize_details(details) + + if len(remaining["_remaining"]) == 0: + return {"done_ranges": done_ranges, **remaining} + + first_detail = remaining["_remaining"][0] + remaining["_remaining"] = remaining["_remaining"][1:] + + for date_str_tuple in first_detail: + try: + range_start, range_end = date_str_tuple + datetime_bounds = ( + dt.datetime.fromisoformat(range_start), + dt.datetime.fromisoformat(range_end), + ) + except (TypeError, ValueError) as e: + raise HeartbeatParseError("done_ranges") from e + + done_ranges.append(datetime_bounds) + + return {"done_ranges": done_ranges, **remaining} + + def serialize_details(self) -> tuple[typing.Any, ...]: + """Serialize this into a tuple. + + Each datetime from `self.done_ranges` must be cast to string as values must + be JSON-serializable. + """ + serialized_done_ranges = [ + (start.isoformat() if start is not None else start, end.isoformat()) for (start, end) in self.done_ranges + ] + serialized_parent_details = super().serialize_details() + return (*serialized_parent_details[:-1], serialized_done_ranges, self._remaining) + + @property + def empty(self) -> bool: + return len(self.done_ranges) == 0 + + def track_done_range( + self, done_range: DateRange, data_interval_start_input: str | dt.datetime | None, merge: bool = True + ): + """Track a range of datetime values that has been exported successfully. + + If this is the first `done_range` then we override the beginning of the + range to ensure it covers the range from `data_interval_start_input`. + + Arguments: + done_range: A date range of values that have been exported. + data_interval_start_input: The `data_interval_start` input passed to + the batch export + merge: Whether to merge the new range with existing ones. + """ + if self.empty is True: + if data_interval_start_input is None: + data_interval_start = dt.datetime.fromtimestamp(0, tz=dt.UTC) + elif isinstance(data_interval_start_input, str): + data_interval_start = dt.datetime.fromisoformat(data_interval_start_input) + else: + data_interval_start = data_interval_start_input + + done_range = (data_interval_start, done_range[1]) + + self.insert_done_range(done_range, merge=merge) + + def insert_done_range(self, done_range: DateRange, merge: bool = True): + """Insert a date range into `self.done_ranges` in order.""" + for index, range in enumerate(self.done_ranges, start=0): + if done_range[0] > range[1]: + continue + + # We have found the index where this date range should go in. + if done_range[0] == range[1]: + self.done_ranges.insert(index + 1, done_range) + else: + self.done_ranges.insert(index, done_range) + break + else: + # Date range should go at the end + self.done_ranges.append(done_range) + + if merge: + self.merge_done_ranges() + + def merge_done_ranges(self): + """Merge as many date ranges together as possible in `self.done_ranges`. + + This method looks for ranges whose opposite ends are touching and merges + them together. Notice that this method does not have enough information + to merge ranges that are not touching. + """ + marked_for_deletion = set() + for index, range in enumerate(self.done_ranges, start=0): + if index in marked_for_deletion: + continue + try: + next_range = self.done_ranges[index + 1] + except IndexError: + continue + + if next_range[0] == range[1]: + # Touching start of next range with end of range. + # End of next range set as end of existing range. + # Next range marked for deletion as it's now covered by range. + self.done_ranges[index] = (range[0], next_range[1]) + marked_for_deletion.add(index + 1) + + for index in marked_for_deletion: + self.done_ranges.pop(index) + + def complete_done_ranges(self, data_interval_end_input: str | dt.datetime): + """Complete the entire range covered by the batch export. + + This is meant to be called at the end of a batch export to ensure + `self.done_ranges` covers the entire batch period from whichever was the + first range tracked until `data_interval_end_input`. + + All ranges will be essentially merged into one (well, replaced by one) + covering everything, so it is very important to only call this once + everything is done. + """ + if isinstance(data_interval_end_input, str): + data_interval_end = dt.datetime.fromisoformat(data_interval_end_input) + else: + data_interval_end = data_interval_end_input + + self.done_ranges = [(self.done_ranges[0][0], data_interval_end)] + + +HeartbeatType = typing.TypeVar("HeartbeatType", bound=HeartbeatDetails) + + +async def should_resume_from_activity_heartbeat( + activity, heartbeat_type: type[HeartbeatType] +) -> tuple[bool, HeartbeatType | None]: + """Check if a batch export should resume from an activity's heartbeat details. + + We understand that a batch export should resume any time that we receive heartbeat details and + those details can be correctly parsed. However, the decision is ultimately up to the batch export + activity to decide if it must resume and how to do so. + + Returns: + A tuple with the first element indicating if the batch export should resume. If the first element + is True, the second tuple element will be the heartbeat details themselves, otherwise None. + """ + try: + heartbeat_details = heartbeat_type.from_activity(activity) + + except EmptyHeartbeatError: + # We don't log this as it's the expected exception when heartbeat is empty. + heartbeat_details = None + received = False + + except NotEnoughHeartbeatValuesError: + heartbeat_details = None + received = False + await logger.awarning("Details from previous activity execution did not contain the expected amount of values") + + except HeartbeatParseError: + heartbeat_details = None + received = False + await logger.awarning("Details from previous activity execution could not be parsed.") + + except Exception: + # We should start from the beginning, but we make a point to log unexpected errors. + # Ideally, any new exceptions should be added to the previous blocks after the first time and we will never land here. + heartbeat_details = None + received = False + await logger.aexception("Did not receive details from previous activity Execution due to an unexpected error") + + else: + received = True + await logger.adebug( + f"Received details from previous activity: {heartbeat_details}", + ) + + return received, heartbeat_details diff --git a/posthog/temporal/batch_exports/redshift_batch_export.py b/posthog/temporal/batch_exports/redshift_batch_export.py index 9a2ad891d2e1b..d9d634d78858c 100644 --- a/posthog/temporal/batch_exports/redshift_batch_export.py +++ b/posthog/temporal/batch_exports/redshift_batch_export.py @@ -47,7 +47,11 @@ from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import configure_temporal_worker_logger -from posthog.temporal.common.utils import BatchExportHeartbeatDetails, should_resume_from_activity_heartbeat +from posthog.temporal.batch_exports.heartbeat import ( + BatchExportRangeHeartbeatDetails, + DateRange, + should_resume_from_activity_heartbeat, +) def remove_escaped_whitespace_recursive(value): @@ -273,7 +277,7 @@ def get_redshift_fields_from_record_schema( @dataclasses.dataclass -class RedshiftHeartbeatDetails(BatchExportHeartbeatDetails): +class RedshiftHeartbeatDetails(BatchExportRangeHeartbeatDetails): """The Redshift batch export details included in every heartbeat.""" pass @@ -285,6 +289,9 @@ async def insert_records_to_redshift( schema: str | None, table: str, heartbeater: Heartbeater, + heartbeat_details: RedshiftHeartbeatDetails, + data_interval_start: dt.datetime | None, + data_interval_end: dt.datetime, batch_size: int = 100, use_super: bool = False, known_super_columns: list[str] | None = None, @@ -352,7 +359,11 @@ async def flush_to_redshift(batch): # the byte size of each batch the way things are currently written. We can revisit this # in the future if we decide it's useful enough. + batch_start_inserted_at = None async for record, _inserted_at in records_iterator: + if batch_start_inserted_at is None: + batch_start_inserted_at = _inserted_at + for column in columns: if known_super_columns is not None and column in known_super_columns: record[column] = json.dumps(record[column], ensure_ascii=False) @@ -362,12 +373,24 @@ async def flush_to_redshift(batch): continue await flush_to_redshift(batch) - heartbeater.details = (str(_inserted_at),) + + last_date_range = (batch_start_inserted_at, _inserted_at) + heartbeat_details.track_done_range(last_date_range, data_interval_start) + heartbeater.set_from_heartbeat_details(heartbeat_details) + + batch_start_inserted_at = None batch = [] - if len(batch) > 0: + if len(batch) > 0 and batch_start_inserted_at: await flush_to_redshift(batch) - heartbeater.details = (str(_inserted_at),) + + last_date_range = (batch_start_inserted_at, _inserted_at) + + heartbeat_details.track_done_range(last_date_range, data_interval_start) + heartbeater.set_from_heartbeat_details(heartbeat_details) + + heartbeat_details.complete_done_ranges(data_interval_end) + heartbeater.set_from_heartbeat_details(heartbeat_details) return total_rows_exported @@ -420,12 +443,11 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - should_resume, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails, logger) + _, details = await should_resume_from_activity_heartbeat(activity, RedshiftHeartbeatDetails) + if details is None: + details = RedshiftHeartbeatDetails() - if should_resume is True and details is not None: - data_interval_start: str | None = details.last_inserted_at.isoformat() - else: - data_interval_start = inputs.data_interval_start + done_ranges: list[DateRange] = details.done_ranges model: BatchExportModel | BatchExportSchema | None = None if inputs.batch_export_schema is None and "batch_export_model" in { @@ -446,13 +468,19 @@ async def insert_into_redshift_activity(inputs: RedshiftInsertInputs) -> Records extra_query_parameters = model["values"] if model is not None else {} fields = model["fields"] if model is not None else None + data_interval_start = ( + dt.datetime.fromisoformat(inputs.data_interval_start) if inputs.data_interval_start else None + ) + data_interval_end = dt.datetime.fromisoformat(inputs.data_interval_end) + full_range = (data_interval_start, data_interval_end) + queue, produce_task = start_produce_batch_export_record_batches( client=client, model_name=model_name, is_backfill=inputs.is_backfill, team_id=inputs.team_id, - interval_start=data_interval_start, - interval_end=inputs.data_interval_end, + full_range=full_range, + done_ranges=done_ranges, exclude_events=inputs.exclude_events, include_events=inputs.include_events, fields=fields, @@ -545,7 +573,12 @@ def map_to_record(row: dict) -> tuple[dict, dt.datetime]: # TODO: We should be able to save a json.loads here. record[column] = remove_escaped_whitespace_recursive(json.loads(record[column])) - return record, row["_inserted_at"] + if isinstance(row["_inserted_at"], int): + inserted_at = dt.datetime.fromtimestamp(row["_inserted_at"]) + else: + inserted_at = row["_inserted_at"] + + return record, inserted_at async def record_generator() -> ( collections.abc.AsyncGenerator[tuple[dict[str, typing.Any], dt.datetime], None] @@ -574,6 +607,9 @@ async def record_generator() -> ( heartbeater=heartbeater, use_super=properties_type == "SUPER", known_super_columns=known_super_columns, + heartbeat_details=details, + data_interval_start=data_interval_start, + data_interval_end=data_interval_end, ) if requires_merge: diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 4f88c1574d235..908f7961a548a 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -6,6 +6,7 @@ import json import posixpath import typing +import collections.abc import aioboto3 import botocore.exceptions @@ -52,6 +53,12 @@ from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger +from posthog.temporal.batch_exports.heartbeat import ( + BatchExportRangeHeartbeatDetails, + DateRange, + HeartbeatParseError, + should_resume_from_activity_heartbeat, +) def get_allowed_template_variables(inputs) -> dict[str, str]: @@ -379,22 +386,51 @@ async def __aexit__(self, exc_type, exc_value, traceback) -> bool: return False -class HeartbeatDetails(typing.NamedTuple): +@dataclasses.dataclass +class S3HeartbeatDetails(BatchExportRangeHeartbeatDetails): """This tuple allows us to enforce a schema on the Heartbeat details. Attributes: - last_uploaded_part_timestamp: The timestamp of the last part we managed to upload. upload_state: State to continue a S3MultiPartUpload when activity execution resumes. """ - last_uploaded_part_timestamp: str - upload_state: S3MultiPartUploadState + upload_state: S3MultiPartUploadState | None = None @classmethod - def from_activity_details(cls, details): - last_uploaded_part_timestamp = details[0] - upload_state = S3MultiPartUploadState(*details[1]) - return cls(last_uploaded_part_timestamp, upload_state) + def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]: + """Attempt to initialize HeartbeatDetails from an activity's details.""" + upload_state = None + remaining = super().deserialize_details(details) + + if len(remaining["_remaining"]) == 0: + return {"upload_state": upload_state, **remaining} + + first_detail = remaining["_remaining"][0] + remaining["_remaining"] = remaining["_remaining"][1:] + + if first_detail is None: + return {"upload_state": None, **remaining} + + try: + upload_state = S3MultiPartUploadState(*first_detail) + except (TypeError, ValueError) as e: + raise HeartbeatParseError("upload_state") from e + + return {"upload_state": upload_state, **remaining} + + def serialize_details(self) -> tuple[typing.Any, ...]: + """Attempt to initialize HeartbeatDetails from an activity's details.""" + serialized_parent_details = super().serialize_details() + return (*serialized_parent_details[:-1], self.upload_state, self._remaining) + + def append_upload_state(self, upload_state: S3MultiPartUploadState): + if self.upload_state is None: + self.upload_state = upload_state + + current_parts = {part["PartNumber"] for part in self.upload_state.parts} + for part in upload_state.parts: + if part["PartNumber"] not in current_parts: + self.upload_state.parts.append(part) @dataclasses.dataclass @@ -428,7 +464,9 @@ class S3InsertInputs: batch_export_schema: BatchExportSchema | None = None -async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str | None]: +async def initialize_and_resume_multipart_upload( + inputs: S3InsertInputs, +) -> tuple[S3MultiPartUpload, S3HeartbeatDetails]: """Initialize a S3MultiPartUpload and resume it from a hearbeat state if available.""" logger = await bind_temporal_worker_logger(team_id=inputs.team_id, destination="S3") key = get_s3_key(inputs) @@ -444,34 +482,16 @@ async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tupl endpoint_url=inputs.endpoint_url, ) - details = activity.info().heartbeat_details + _, details = await should_resume_from_activity_heartbeat(activity, S3HeartbeatDetails) + if details is None: + details = S3HeartbeatDetails() - try: - interval_start, upload_state = HeartbeatDetails.from_activity_details(details) - except IndexError: - # This is the error we expect when no details as the sequence will be empty. - interval_start = inputs.data_interval_start - await logger.adebug( - "Did not receive details from previous activity Execution. Export will start from the beginning %s", - interval_start, - ) - except Exception: - # We still start from the beginning, but we make a point to log unexpected errors. - # Ideally, any new exceptions should be added to the previous block after the first time and we will never land here. - interval_start = inputs.data_interval_start - await logger.awarning( - "Did not receive details from previous activity Execution due to an unexpected error. Export will start from the beginning %s", - interval_start, - ) - else: - await logger.ainfo( - "Received details from previous activity. Export will attempt to resume from %s", - interval_start, - ) - s3_upload.continue_from_state(upload_state) + if details.upload_state: + s3_upload.continue_from_state(details.upload_state) if inputs.compression == "brotli": - # Even if we receive details we cannot resume a brotli compressed upload as we have lost the compressor state. + # Even if we receive details we cannot resume a brotli compressed upload as + # we have lost the compressor state. interval_start = inputs.data_interval_start await logger.ainfo( @@ -480,7 +500,7 @@ async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tupl ) await s3_upload.abort() - return s3_upload, interval_start + return s3_upload, details def s3_default_fields() -> list[BatchExportField]: @@ -527,7 +547,14 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted: if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - s3_upload, interval_start = await initialize_and_resume_multipart_upload(inputs) + s3_upload, details = await initialize_and_resume_multipart_upload(inputs) + + # TODO: Switch to single-producer multiple consumer + done_ranges: list[DateRange] = details.done_ranges + if done_ranges: + data_interval_start: str | None = done_ranges[-1][1].isoformat() + else: + data_interval_start = inputs.data_interval_start model: BatchExportModel | BatchExportSchema | None = None if inputs.batch_export_schema is None and "batch_export_model" in { @@ -541,7 +568,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> RecordsCompleted: model=model, client=client, team_id=inputs.team_id, - interval_start=interval_start, + interval_start=data_interval_start, interval_end=inputs.data_interval_end, exclude_events=inputs.exclude_events, include_events=inputs.include_events, @@ -562,13 +589,13 @@ async def flush_to_s3( records_since_last_flush: int, bytes_since_last_flush: int, flush_counter: int, - last_inserted_at: dt.datetime, + last_date_range: DateRange, last: bool, error: Exception | None, ): if error is not None: await logger.adebug("Error while writing part %d", s3_upload.part_number + 1, exc_info=error) - await logger.awarn( + await logger.awarning( "An error was detected while writing part %d. Partial part will not be uploaded in case it can be retried.", s3_upload.part_number + 1, ) @@ -587,7 +614,9 @@ async def flush_to_s3( rows_exported.add(records_since_last_flush) bytes_exported.add(bytes_since_last_flush) - heartbeater.details = (str(last_inserted_at), s3_upload.to_state()) + details.track_done_range(last_date_range, data_interval_start) + details.append_upload_state(s3_upload.to_state()) + heartbeater.set_from_heartbeat_details(details) first_record_batch = cast_record_batch_json_columns(first_record_batch) column_names = first_record_batch.column_names @@ -618,6 +647,9 @@ async def flush_to_s3( await writer.write_record_batch(record_batch) + details.complete_done_ranges(inputs.data_interval_end) + heartbeater.set_from_heartbeat_details(details) + records_completed = writer.records_total await s3_upload.complete() diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index d3f7b1e3a023f..8887fcd52a317 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -51,10 +51,10 @@ from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.heartbeat import Heartbeater from posthog.temporal.common.logger import bind_temporal_worker_logger -from posthog.temporal.common.utils import ( - BatchExportHeartbeatDetails, +from posthog.temporal.batch_exports.heartbeat import ( + BatchExportRangeHeartbeatDetails, + DateRange, HeartbeatParseError, - NotEnoughHeartbeatValuesError, should_resume_from_activity_heartbeat, ) @@ -90,28 +90,38 @@ class SnowflakeRetryableConnectionError(Exception): @dataclasses.dataclass -class SnowflakeHeartbeatDetails(BatchExportHeartbeatDetails): +class SnowflakeHeartbeatDetails(BatchExportRangeHeartbeatDetails): """The Snowflake batch export details included in every heartbeat. Attributes: file_no: The file number of the last file we managed to upload. """ - file_no: int + file_no: int = 0 @classmethod - def from_activity(cls, activity): - details = BatchExportHeartbeatDetails.from_activity(activity) + def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]: + """Attempt to initialize HeartbeatDetails from an activity's details.""" + file_no = 0 + remaining = super().deserialize_details(details) - if details.total_details < 2: - raise NotEnoughHeartbeatValuesError(details.total_details, 2) + if len(remaining["_remaining"]) == 0: + return {"file_no": 0, **remaining} + + first_detail = remaining["_remaining"][0] + remaining["_remaining"] = remaining["_remaining"][1:] try: - file_no = int(details._remaining[0]) + file_no = int(first_detail) except (TypeError, ValueError) as e: raise HeartbeatParseError("file_no") from e - return cls(last_inserted_at=details.last_inserted_at, file_no=file_no, _remaining=details._remaining[2:]) + return {"file_no": file_no, **remaining} + + def serialize_details(self) -> tuple[typing.Any, ...]: + """Attempt to initialize HeartbeatDetails from an activity's details.""" + serialized_parent_details = super().serialize_details() + return (*serialized_parent_details[:-1], self.file_no, self._remaining) @dataclasses.dataclass @@ -579,16 +589,17 @@ async def insert_into_snowflake_activity(inputs: SnowflakeInsertInputs) -> Recor if not await client.is_alive(): raise ConnectionError("Cannot establish connection to ClickHouse") - should_resume, details = await should_resume_from_activity_heartbeat( - activity, SnowflakeHeartbeatDetails, logger - ) + _, details = await should_resume_from_activity_heartbeat(activity, SnowflakeHeartbeatDetails) + if details is None: + details = SnowflakeHeartbeatDetails() - if should_resume is True and details is not None: - data_interval_start: str | None = details.last_inserted_at.isoformat() - current_flush_counter = details.file_no + done_ranges: list[DateRange] = details.done_ranges + if done_ranges: + data_interval_start: str | None = done_ranges[-1][1].isoformat() else: data_interval_start = inputs.data_interval_start - current_flush_counter = 0 + + current_flush_counter = details.file_no rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() @@ -670,7 +681,7 @@ async def flush_to_snowflake( records_since_last_flush, bytes_since_last_flush, flush_counter: int, - last_inserted_at, + last_date_range: DateRange, last: bool, error: Exception | None, ): @@ -690,7 +701,9 @@ async def flush_to_snowflake( rows_exported.add(records_since_last_flush) bytes_exported.add(bytes_since_last_flush) - heartbeater.details = (str(last_inserted_at), flush_counter) + details.track_done_range(last_date_range, data_interval_start) + details.file_no = flush_counter + heartbeater.set_from_heartbeat_details(details) writer = JSONLBatchExportWriter( max_bytes=settings.BATCH_EXPORT_SNOWFLAKE_UPLOAD_CHUNK_SIZE_BYTES, @@ -703,6 +716,9 @@ async def flush_to_snowflake( await writer.write_record_batch(record_batch) + details.complete_done_ranges(inputs.data_interval_end) + heartbeater.set_from_heartbeat_details(details) + await snow_client.copy_loaded_files_to_snowflake_table( snow_stage_table if requires_merge else snow_table, data_interval_end_str ) diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index d26db8b976171..54beae9f9b1d5 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -17,6 +17,8 @@ import pyarrow.parquet as pq import structlog +from posthog.temporal.batch_exports.heartbeat import DateRange + logger = structlog.get_logger() @@ -247,7 +249,6 @@ def reset(self): self.records_since_last_reset = 0 -LastInsertedAt = dt.datetime IsLast = bool RecordsSinceLastFlush = int BytesSinceLastFlush = int @@ -258,7 +259,7 @@ def reset(self): RecordsSinceLastFlush, BytesSinceLastFlush, FlushCounter, - LastInsertedAt, + DateRange, IsLast, Exception | None, ], @@ -318,7 +319,9 @@ def __init__( def reset_writer_tracking(self): """Reset this writer's tracking state.""" - self.last_inserted_at: dt.datetime | None = None + self.start_at_since_last_flush: dt.datetime | None = None + self.end_at_since_last_flush: dt.datetime | None = None + self.flushed_date_ranges: list[DateRange] = [] self.records_total = 0 self.records_since_last_flush = 0 self.bytes_total = 0 @@ -326,6 +329,13 @@ def reset_writer_tracking(self): self.flush_counter = 0 self.error = None + @property + def date_range_since_last_flush(self) -> DateRange | None: + if self.start_at_since_last_flush is not None and self.end_at_since_last_flush is not None: + return (self.start_at_since_last_flush, self.end_at_since_last_flush) + else: + return None + @contextlib.asynccontextmanager async def open_temporary_file(self, current_flush_counter: int = 0): """Explicitly open the temporary file this writer is writing to. @@ -352,12 +362,12 @@ async def open_temporary_file(self, current_flush_counter: int = 0): finally: self.track_bytes_written(temp_file) - if self.last_inserted_at is not None and self.bytes_since_last_flush > 0: + if self.bytes_since_last_flush > 0: # `bytes_since_last_flush` should be 0 unless: # 1. The last batch wasn't flushed as it didn't reach `max_bytes`. # 2. The last batch was flushed but there was another write after the last call to # `write_record_batch`. For example, footer bytes. - await self.flush(self.last_inserted_at, is_last=True) + await self.flush(is_last=True) self._batch_export_file = None @@ -394,24 +404,38 @@ def track_bytes_written(self, batch_export_file: BatchExportTemporaryFile) -> No async def write_record_batch(self, record_batch: pa.RecordBatch, flush: bool = True) -> None: """Issue a record batch write tracking progress and flushing if required.""" record_batch = record_batch.sort_by("_inserted_at") - last_inserted_at = record_batch.column("_inserted_at")[-1].as_py() + + if self.start_at_since_last_flush is None: + raw_start_at = record_batch.column("_inserted_at")[0].as_py() + if isinstance(raw_start_at, int): + try: + self.start_at_since_last_flush = dt.datetime.fromtimestamp(raw_start_at, tz=dt.UTC) + except Exception: + raise + else: + self.start_at_since_last_flush = raw_start_at + + raw_end_at = record_batch.column("_inserted_at")[-1].as_py() + if isinstance(raw_end_at, int): + self.end_at_since_last_flush = dt.datetime.fromtimestamp(raw_end_at, tz=dt.UTC) + else: + self.end_at_since_last_flush = raw_end_at column_names = record_batch.column_names column_names.pop(column_names.index("_inserted_at")) await asyncio.to_thread(self._write_record_batch, record_batch.select(column_names)) - self.last_inserted_at = last_inserted_at self.track_records_written(record_batch) self.track_bytes_written(self.batch_export_file) if flush and self.should_flush(): - await self.flush(last_inserted_at) + await self.flush() def should_flush(self) -> bool: return self.bytes_since_last_flush >= self.max_bytes - async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: + async def flush(self, is_last: bool = False) -> None: """Call the provided `flush_callable` and reset underlying file. The underlying batch export temporary file will be reset after calling `flush_callable`. @@ -421,12 +445,15 @@ async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> N self.batch_export_file.seek(0) + if self.date_range_since_last_flush is not None: + self.flushed_date_ranges.append(self.date_range_since_last_flush) + await self.flush_callable( self.batch_export_file, self.records_since_last_flush, self.bytes_since_last_flush, self.flush_counter, - last_inserted_at, + self.flushed_date_ranges[-1], is_last, self.error, ) @@ -435,6 +462,8 @@ async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> N self.records_since_last_flush = 0 self.bytes_since_last_flush = 0 self.flush_counter += 1 + self.start_at_since_last_flush = None + self.end_at_since_last_flush = None class JSONLBatchExportWriter(BatchExportWriter): diff --git a/posthog/temporal/common/heartbeat.py b/posthog/temporal/common/heartbeat.py index cb0d82fa23baf..d6f9463e16df8 100644 --- a/posthog/temporal/common/heartbeat.py +++ b/posthog/temporal/common/heartbeat.py @@ -1,5 +1,8 @@ import asyncio import typing +import dataclasses +import collections.abc +import abc from temporalio import activity @@ -20,7 +23,7 @@ class Heartbeater: maintained while in the context manager to avoid garbage collection. """ - def __init__(self, details: tuple[typing.Any, ...] = (), factor: int = 12): + def __init__(self, details: tuple[typing.Any, ...] = (), factor: int = 120): self._details: tuple[typing.Any, ...] = details self.factor = factor self.heartbeat_task: asyncio.Task | None = None @@ -36,6 +39,10 @@ def details(self, details: tuple[typing.Any, ...]) -> None: """Set tuple to be passed as heartbeat details.""" self._details = details + def set_from_heartbeat_details(self, details: "HeartbeatDetails") -> None: + """Set `HeartbeatDetails` to be passed as heartbeat details.""" + self._details = tuple(details.serialize_details()) + async def __aenter__(self): """Enter managed heartbeatting context.""" @@ -82,3 +89,116 @@ async def __aexit__(self, *args, **kwargs): self.heartbeat_task = None self.heartbeat_on_shutdown_task = None + + +class EmptyHeartbeatError(Exception): + """Raised when an activity heartbeat is empty. + + This is also the error we expect when no heartbeatting is happening, as the sequence will be empty. + """ + + def __init__(self): + super().__init__(f"Heartbeat details sequence is empty") + + +class NotEnoughHeartbeatValuesError(Exception): + """Raised when an activity heartbeat doesn't contain the right amount of values we expect.""" + + def __init__(self, details_len: int, expected: int): + super().__init__(f"Not enough values in heartbeat details (expected {expected}, got {details_len})") + + +class HeartbeatParseError(Exception): + """Raised when an activity heartbeat cannot be parsed into it's expected types.""" + + def __init__(self, field: str): + super().__init__(f"Parsing {field} from heartbeat details encountered an error") + + +@dataclasses.dataclass +class HeartbeatDetails(metaclass=abc.ABCMeta): + """Details included in every heartbeat. + + If an activity requires tracking progress, this should be subclassed to include + the attributes that are required for said activity. The main methods to implement + when subclassing are `deserialize_details` and `serialize_details`. Both should + deserialize from and serialize to a generic sequence or tuple, respectively. + + Attributes: + _remaining: Any remaining values in the heartbeat_details tuple that we do + not parse. + """ + + _remaining: collections.abc.Sequence[typing.Any] + + @property + def total_details(self) -> int: + """The total number of details that we have parsed + those remaining to parse.""" + return (len(dataclasses.fields(self.__class__)) - 1) + len(self._remaining) + + @classmethod + @abc.abstractmethod + def deserialize_details(cls, details: collections.abc.Sequence[typing.Any]) -> dict[str, typing.Any]: + """Deserialize `HeartbeatDetails` from a generic sequence of details. + + This base class implementation just returns all details as `_remaining`. + Subclasses first call this method, and then peek into `_remaining` and + extract the details they need. For now, subclasses can only rely on the + order in which details are serialized but in the future we may need a + more robust way of identifying details. + + Arguments: + details: A collection of details as returned by + `temporalio.activity.info().heartbeat_details` + """ + return {"_remaining": details} + + @abc.abstractmethod + def serialize_details(self) -> tuple[typing.Any, ...]: + """Serialize `HeartbeatDetails` to a tuple. + + Since subclasses rely on the order details are serialized, subclasses + should be careful here to maintain a consistent serialization order. For + example, `_remaining` should always be placed last. + + Returns: + A tuple of serialized details. + """ + return (self._remaining,) + + @classmethod + def from_activity(cls, activity): + """Instantiate this class from a Temporal Activity.""" + details = activity.info().heartbeat_details + return cls.from_activity_details(details) + + @classmethod + def from_activity_details(cls, details): + parsed = cls.deserialize_details(details) + return cls(**parsed) + + +@dataclasses.dataclass +class DataImportHeartbeatDetails(HeartbeatDetails): + """Data import heartbeat details. + + Attributes: + endpoint: The endpoint we are importing data from. + cursor: The cursor we are using to paginate through the endpoint. + """ + + endpoint: str + cursor: str + + @classmethod + def from_activity(cls, activity): + """Attempt to initialize DataImportHeartbeatDetails from an activity's info.""" + details = activity.info().heartbeat_details + + if len(details) == 0: + raise EmptyHeartbeatError() + + if len(details) != 2: + raise NotEnoughHeartbeatValuesError(len(details), 2) + + return cls(endpoint=details[0], cursor=details[1], _remaining=details[2:]) diff --git a/posthog/temporal/common/utils.py b/posthog/temporal/common/utils.py deleted file mode 100644 index fc8a77cadea81..0000000000000 --- a/posthog/temporal/common/utils.py +++ /dev/null @@ -1,149 +0,0 @@ -import abc -import collections.abc -import dataclasses -import datetime as dt -import typing - - -class EmptyHeartbeatError(Exception): - """Raised when an activity heartbeat is empty. - - This is also the error we expect when no heartbeatting is happening, as the sequence will be empty. - """ - - def __init__(self): - super().__init__(f"Heartbeat details sequence is empty") - - -class NotEnoughHeartbeatValuesError(Exception): - """Raised when an activity heartbeat doesn't contain the right amount of values we expect.""" - - def __init__(self, details_len: int, expected: int): - super().__init__(f"Not enough values in heartbeat details (expected {expected}, got {details_len})") - - -class HeartbeatParseError(Exception): - """Raised when an activity heartbeat cannot be parsed into it's expected types.""" - - def __init__(self, field: str): - super().__init__(f"Parsing {field} from heartbeat details encountered an error") - - -@dataclasses.dataclass -class HeartbeatDetails(metaclass=abc.ABCMeta): - """The batch export details included in every heartbeat. - - Each batch export destination should subclass this and implement whatever details are specific to that - batch export and required to resume it. - - Attributes: - last_inserted_at: The last inserted_at we managed to upload or insert, depending on the destination. - _remaining: Any remaining values in the heartbeat_details tuple that we do not parse. - """ - - _remaining: collections.abc.Sequence[typing.Any] - - @property - def total_details(self) -> int: - """The total number of details that we have parsed + those remaining to parse.""" - return (len(dataclasses.fields(self.__class__)) - 1) + len(self._remaining) - - @abc.abstractclassmethod - def from_activity(cls, activity): - pass - - -@dataclasses.dataclass -class BatchExportHeartbeatDetails(HeartbeatDetails): - last_inserted_at: dt.datetime - - @classmethod - def from_activity(cls, activity): - """Attempt to initialize HeartbeatDetails from an activity's info.""" - details = activity.info().heartbeat_details - - if len(details) == 0: - raise EmptyHeartbeatError() - - try: - last_inserted_at = dt.datetime.fromisoformat(details[0]) - except (TypeError, ValueError) as e: - raise HeartbeatParseError("last_inserted_at") from e - - return cls(last_inserted_at=last_inserted_at, _remaining=details[1:]) - - -@dataclasses.dataclass -class DataImportHeartbeatDetails(HeartbeatDetails): - """Data import heartbeat details. - - Attributes: - endpoint: The endpoint we are importing data from. - cursor: The cursor we are using to paginate through the endpoint. - """ - - endpoint: str - cursor: str - - @classmethod - def from_activity(cls, activity): - """Attempt to initialize DataImportHeartbeatDetails from an activity's info.""" - details = activity.info().heartbeat_details - - if len(details) == 0: - raise EmptyHeartbeatError() - - if len(details) != 2: - raise NotEnoughHeartbeatValuesError(len(details), 2) - - return cls(endpoint=details[0], cursor=details[1], _remaining=details[2:]) - - -HeartbeatType = typing.TypeVar("HeartbeatType", bound=HeartbeatDetails) - - -async def should_resume_from_activity_heartbeat( - activity, heartbeat_type: type[HeartbeatType], logger -) -> tuple[bool, HeartbeatType | None]: - """Check if a batch export should resume from an activity's heartbeat details. - - We understand that a batch export should resume any time that we receive heartbeat details and - those details can be correctly parsed. However, the decision is ultimately up to the batch export - activity to decide if it must resume and how to do so. - - Returns: - A tuple with the first element indicating if the batch export should resume. If the first element - is True, the second tuple element will be the heartbeat details themselves, otherwise None. - """ - try: - heartbeat_details = heartbeat_type.from_activity(activity) - - except EmptyHeartbeatError: - # We don't log this as it's the expected exception when heartbeat is empty. - heartbeat_details = None - received = False - - except NotEnoughHeartbeatValuesError: - heartbeat_details = None - received = False - await logger.awarning("Details from previous activity execution did not contain the expected amount of values") - - except HeartbeatParseError: - heartbeat_details = None - received = False - await logger.awarning("Details from previous activity execution could not be parsed.") - - except Exception: - # We should start from the beginning, but we make a point to log unexpected errors. - # Ideally, any new exceptions should be added to the previous blocks after the first time and we will never land here. - heartbeat_details = None - received = False - await logger.aexception("Did not receive details from previous activity Execution due to an unexpected error") - - else: - received = True - await logger.adebug( - f"Received details from previous activity: {heartbeat_details}", - ) - - return received, heartbeat_details diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py index d365424e70bea..b8236af8322c9 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_exports.py +++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py @@ -13,6 +13,7 @@ RecordBatchProducerError, RecordBatchQueue, TaskNotDoneError, + generate_query_ranges, get_data_interval, iter_model_records, iter_records, @@ -463,8 +464,8 @@ async def test_start_produce_batch_export_record_batches_uses_extra_query_parame team_id=team_id, is_backfill=False, model_name="events", - interval_start=data_interval_start.isoformat(), - interval_end=data_interval_end.isoformat(), + full_range=(data_interval_start, data_interval_end), + done_ranges=[], fields=[ {"expression": "JSONExtractInt(properties, %(hogql_val_0)s)", "alias": "custom_prop"}, ], @@ -503,8 +504,8 @@ async def test_start_produce_batch_export_record_batches_can_flatten_properties( team_id=team_id, is_backfill=False, model_name="events", - interval_start=data_interval_start.isoformat(), - interval_end=data_interval_end.isoformat(), + full_range=(data_interval_start, data_interval_end), + done_ranges=[], fields=[ {"expression": "event", "alias": "event"}, {"expression": "JSONExtractString(properties, '$browser')", "alias": "browser"}, @@ -560,8 +561,8 @@ async def test_start_produce_batch_export_record_batches_with_single_field_and_a team_id=team_id, is_backfill=False, model_name="events", - interval_start=data_interval_start.isoformat(), - interval_end=data_interval_end.isoformat(), + full_range=(data_interval_start, data_interval_end), + done_ranges=[], fields=[field], extra_query_parameters={}, ) @@ -615,8 +616,8 @@ async def test_start_produce_batch_export_record_batches_ignores_timestamp_predi team_id=team_id, is_backfill=False, model_name="events", - interval_start=inserted_at.isoformat(), - interval_end=data_interval_end.isoformat(), + full_range=(inserted_at, data_interval_end), + done_ranges=[], ) records = await get_all_record_batches_from_queue(queue, produce_task) @@ -629,8 +630,8 @@ async def test_start_produce_batch_export_record_batches_ignores_timestamp_predi team_id=team_id, is_backfill=False, model_name="events", - interval_start=inserted_at.isoformat(), - interval_end=data_interval_end.isoformat(), + full_range=(inserted_at, data_interval_end), + done_ranges=[], ) records = await get_all_record_batches_from_queue(queue, produce_task) @@ -664,8 +665,8 @@ async def test_start_produce_batch_export_record_batches_can_include_events(clic team_id=team_id, is_backfill=False, model_name="events", - interval_start=data_interval_start.isoformat(), - interval_end=data_interval_end.isoformat(), + full_range=(data_interval_start, data_interval_end), + done_ranges=[], include_events=include_events, ) @@ -700,8 +701,8 @@ async def test_start_produce_batch_export_record_batches_can_exclude_events(clic team_id=team_id, is_backfill=False, model_name="events", - interval_start=data_interval_start.isoformat(), - interval_end=data_interval_end.isoformat(), + full_range=(data_interval_start, data_interval_end), + done_ranges=[], exclude_events=exclude_events, ) @@ -733,8 +734,8 @@ async def test_start_produce_batch_export_record_batches_handles_duplicates(clic team_id=team_id, is_backfill=False, model_name="events", - interval_start=data_interval_start.isoformat(), - interval_end=data_interval_end.isoformat(), + full_range=(data_interval_start, data_interval_end), + done_ranges=[], ) records = await get_all_record_batches_from_queue(queue, produce_task) @@ -833,3 +834,119 @@ async def fake_produce_task(): await asyncio.wait([task]) await raise_on_produce_task_failure(task) + + +@pytest.mark.parametrize( + "remaining_range,done_ranges,expected", + [ + # Case 1: One done range at the beginning + ( + (dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)), + [(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC))], + [ + ( + dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC), + ) + ], + ), + # Case 2: Single done range equal to full range. + ( + (dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)), + [(dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC))], + [], + ), + # Case 3: Disconnected done ranges cover full range. + ( + (dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)), + [ + (dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC)), + ( + dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC), + ), + ( + dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC), + ), + ], + [], + ), + # Case 4: Disconnect done ranges within full range. + ( + (dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)), + [ + ( + dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC), + ), + ( + dt.datetime(2023, 7, 31, 12, 50, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 55, 0, tzinfo=dt.UTC), + ), + ], + [ + ( + dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC), + ), + ( + dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 50, 0, tzinfo=dt.UTC), + ), + ( + dt.datetime(2023, 7, 31, 12, 55, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC), + ), + ], + ), + # Case 5: Empty done ranges. + ( + (dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)), + [], + [ + ( + dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC), + ), + ], + ), + # Case 6: Disconnect done ranges within full range and one last done range connected to the end. + ( + (dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC)), + [ + ( + dt.datetime(2023, 7, 31, 12, 15, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 25, 0, tzinfo=dt.UTC), + ), + ( + dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC), + ), + ( + dt.datetime(2023, 7, 31, 12, 50, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 13, 0, 0, tzinfo=dt.UTC), + ), + ], + [ + ( + dt.datetime(2023, 7, 31, 12, 0, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 15, 0, tzinfo=dt.UTC), + ), + ( + dt.datetime(2023, 7, 31, 12, 25, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 30, 0, tzinfo=dt.UTC), + ), + ( + dt.datetime(2023, 7, 31, 12, 45, 0, tzinfo=dt.UTC), + dt.datetime(2023, 7, 31, 12, 50, 0, tzinfo=dt.UTC), + ), + ], + ), + ], + ids=["1", "2", "3", "4", "5", "6"], +) +def test_generate_query_ranges(remaining_range, done_ranges, expected): + """Test get_data_interval returns the expected data interval tuple.""" + result = list(generate_query_ranges(remaining_range, done_ranges)) + assert result == expected diff --git a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py index 22c70cc0d16ee..bc76b062f4d64 100644 --- a/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_bigquery_batch_export_workflow.py @@ -27,6 +27,7 @@ ) from posthog.temporal.batch_exports.bigquery_batch_export import ( BigQueryBatchExportWorkflow, + BigQueryHeartbeatDetails, BigQueryInsertInputs, bigquery_default_fields, get_bigquery_fields_from_record_schema, @@ -54,14 +55,19 @@ TEST_TIME = dt.datetime.now(dt.UTC) +@pytest.fixture +def activity_environment(activity_environment): + activity_environment.heartbeat_class = BigQueryHeartbeatDetails + return activity_environment + + async def assert_clickhouse_records_in_bigquery( bigquery_client: bigquery.Client, clickhouse_client: ClickHouseClient, team_id: int, table_id: str, dataset_id: str, - data_interval_start: dt.datetime, - data_interval_end: dt.datetime, + date_ranges: list[tuple[dt.datetime, dt.datetime]], min_ingested_timestamp: dt.datetime | None = None, exclude_events: list[str] | None = None, include_events: list[str] | None = None, @@ -69,6 +75,7 @@ async def assert_clickhouse_records_in_bigquery( use_json_type: bool = False, sort_key: str = "event", is_backfill: bool = False, + expect_duplicates: bool = False, ) -> None: """Assert ClickHouse records are written to a given BigQuery table. @@ -78,13 +85,13 @@ async def assert_clickhouse_records_in_bigquery( team_id: The ID of the team that we are testing for. table_id: BigQuery table id where records are exported to. dataset_id: BigQuery dataset containing the table where records are exported to. - data_interval_start: Start of the batch period for exported records. - data_interval_end: End of the batch period for exported records. + date_ranges: Ranges of records we should expect to have been exported. min_ingested_timestamp: A datetime used to assert a minimum bound for 'bq_ingested_timestamp'. exclude_events: Event names to be excluded from the export. include_events: Event names to be included in the export. batch_export_schema: Custom schema used in the batch export. use_json_type: Whether to use JSON type for known fields. + expect_duplicates: Whether duplicates are expected (e.g. when testing retrying logic). """ if use_json_type is True: json_columns = ["properties", "set", "set_once", "person_properties"] @@ -135,34 +142,49 @@ async def assert_clickhouse_records_in_bigquery( ] expected_records = [] - async for record_batch in iter_model_records( - client=clickhouse_client, - model=batch_export_model, - team_id=team_id, - interval_start=data_interval_start.isoformat(), - interval_end=data_interval_end.isoformat(), - exclude_events=exclude_events, - include_events=include_events, - destination_default_fields=bigquery_default_fields(), - is_backfill=is_backfill, - ): - for record in record_batch.select(schema_column_names).to_pylist(): - expected_record = {} - - for k, v in record.items(): - if k not in schema_column_names or k == "_inserted_at" or k == "bq_ingested_timestamp": - # _inserted_at is not exported, only used for tracking progress. - # bq_ingested_timestamp cannot be compared as it comes from an unstable function. - continue - - if k in json_columns and v is not None: - expected_record[k] = json.loads(v) - elif isinstance(v, dt.datetime): - expected_record[k] = v.replace(tzinfo=dt.UTC) - else: - expected_record[k] = v - - expected_records.append(expected_record) + for data_interval_start, data_interval_end in date_ranges: + async for record_batch in iter_model_records( + client=clickhouse_client, + model=batch_export_model, + team_id=team_id, + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + exclude_events=exclude_events, + include_events=include_events, + destination_default_fields=bigquery_default_fields(), + is_backfill=is_backfill, + ): + for record in record_batch.select(schema_column_names).to_pylist(): + expected_record = {} + + for k, v in record.items(): + if k not in schema_column_names or k == "_inserted_at" or k == "bq_ingested_timestamp": + # _inserted_at is not exported, only used for tracking progress. + # bq_ingested_timestamp cannot be compared as it comes from an unstable function. + continue + + if k in json_columns and v is not None: + expected_record[k] = json.loads(v) + elif isinstance(v, dt.datetime): + expected_record[k] = v.replace(tzinfo=dt.UTC) + else: + expected_record[k] = v + + expected_records.append(expected_record) + + if expect_duplicates: + seen = set() + + def is_record_seen(record) -> bool: + nonlocal seen + + if record["uuid"] in seen: + return True + + seen.add(record["uuid"]) + return False + + inserted_records = [record for record in inserted_records if not is_record_seen(record)] assert len(inserted_records) == len(expected_records) @@ -328,8 +350,7 @@ async def test_insert_into_bigquery_activity_inserts_data_into_bigquery_table( table_id=f"test_insert_activity_table_{ateam.pk}", dataset_id=bigquery_dataset.dataset_id, team_id=ateam.pk, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, + date_ranges=[(data_interval_start, data_interval_end)], exclude_events=exclude_events, include_events=None, batch_export_model=model, @@ -382,8 +403,7 @@ async def test_insert_into_bigquery_activity_merges_data_in_follow_up_runs( table_id=f"test_insert_activity_mutability_table_{ateam.pk}", dataset_id=bigquery_dataset.dataset_id, team_id=ateam.pk, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, + date_ranges=[(data_interval_start, data_interval_end)], batch_export_model=model, min_ingested_timestamp=ingested_timestamp, sort_key="person_id", @@ -423,14 +443,235 @@ async def test_insert_into_bigquery_activity_merges_data_in_follow_up_runs( table_id=f"test_insert_activity_mutability_table_{ateam.pk}", dataset_id=bigquery_dataset.dataset_id, team_id=ateam.pk, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, + date_ranges=[(data_interval_start, data_interval_end)], batch_export_model=model, min_ingested_timestamp=ingested_timestamp, sort_key="person_id", ) +@pytest.mark.parametrize("interval", ["hour"], indirect=True) +@pytest.mark.parametrize( + "done_relative_ranges,expected_relative_ranges", + [ + ( + [(dt.timedelta(minutes=0), dt.timedelta(minutes=15))], + [(dt.timedelta(minutes=15), dt.timedelta(minutes=60))], + ), + ( + [ + (dt.timedelta(minutes=10), dt.timedelta(minutes=15)), + (dt.timedelta(minutes=35), dt.timedelta(minutes=45)), + ], + [ + (dt.timedelta(minutes=0), dt.timedelta(minutes=10)), + (dt.timedelta(minutes=15), dt.timedelta(minutes=35)), + (dt.timedelta(minutes=45), dt.timedelta(minutes=60)), + ], + ), + ( + [ + (dt.timedelta(minutes=45), dt.timedelta(minutes=60)), + ], + [ + (dt.timedelta(minutes=0), dt.timedelta(minutes=45)), + ], + ), + ], +) +async def test_insert_into_bigquery_activity_resumes_from_heartbeat( + clickhouse_client, + activity_environment, + bigquery_client, + bigquery_config, + bigquery_dataset, + generate_test_data, + data_interval_start, + data_interval_end, + ateam, + done_relative_ranges, + expected_relative_ranges, +): + """Test we insert partial data into a BigQuery table when resuming. + + After an activity runs, heartbeats, and crashes, a follow-up activity should + pick-up from where the first one left. This capability is critical to ensure + long-running activities that export a lot of data will eventually finish. + """ + batch_export_model = BatchExportModel(name="events", schema=None) + + insert_inputs = BigQueryInsertInputs( + team_id=ateam.pk, + table_id=f"test_insert_activity_table_{ateam.pk}", + dataset_id=bigquery_dataset.dataset_id, + data_interval_start=data_interval_start.isoformat(), + data_interval_end=data_interval_end.isoformat(), + use_json_type=True, + batch_export_model=batch_export_model, + **bigquery_config, + ) + + now = dt.datetime.now(tz=dt.UTC) + done_ranges = [ + ( + (data_interval_start + done_relative_range[0]).isoformat(), + (data_interval_start + done_relative_range[1]).isoformat(), + ) + for done_relative_range in done_relative_ranges + ] + expected_ranges = [ + ( + (data_interval_start + expected_relative_range[0]), + (data_interval_start + expected_relative_range[1]), + ) + for expected_relative_range in expected_relative_ranges + ] + workflow_id = uuid.uuid4() + + fake_info = activity.Info( + activity_id="insert-into-bigquery-activity", + activity_type="unknown", + current_attempt_scheduled_time=dt.datetime.now(dt.UTC), + workflow_id=str(workflow_id), + workflow_type="bigquery-export", + workflow_run_id=str(uuid.uuid4()), + attempt=1, + heartbeat_timeout=dt.timedelta(seconds=1), + heartbeat_details=[done_ranges], + is_local=False, + schedule_to_close_timeout=dt.timedelta(seconds=10), + scheduled_time=dt.datetime.now(dt.UTC), + start_to_close_timeout=dt.timedelta(seconds=20), + started_time=dt.datetime.now(dt.UTC), + task_queue="test", + task_token=b"test", + workflow_namespace="default", + ) + + activity_environment.info = fake_info + await activity_environment.run(insert_into_bigquery_activity, insert_inputs) + + await assert_clickhouse_records_in_bigquery( + bigquery_client=bigquery_client, + clickhouse_client=clickhouse_client, + table_id=f"test_insert_activity_table_{ateam.pk}", + dataset_id=bigquery_dataset.dataset_id, + team_id=ateam.pk, + date_ranges=expected_ranges, + include_events=None, + batch_export_model=batch_export_model, + use_json_type=True, + min_ingested_timestamp=now, + sort_key="event", + ) + + +async def test_insert_into_bigquery_activity_completes_range( + clickhouse_client, + activity_environment, + bigquery_client, + bigquery_config, + bigquery_dataset, + generate_test_data, + data_interval_start, + data_interval_end, + ateam, +): + """Test we complete a full range of data into a BigQuery table when resuming. + + We run two activities: + 1. First activity, up to (and including) the cutoff event. + 2. Second activity with a heartbeat detail matching the cutoff event. + + This simulates the batch export resuming from a failed execution. The full range + should be completed (with a duplicate on the cutoff event) after both activities + are done. + """ + batch_export_model = BatchExportModel(name="events", schema=None) + now = dt.datetime.now(tz=dt.UTC) + + events_to_export_created, _ = generate_test_data + events_to_export_created.sort(key=operator.itemgetter("inserted_at")) + + cutoff_event = events_to_export_created[len(events_to_export_created) // 2 : len(events_to_export_created) // 2 + 1] + assert len(cutoff_event) == 1 + cutoff_event = cutoff_event[0] + cutoff_data_interval_end = dt.datetime.fromisoformat(cutoff_event["inserted_at"]).replace(tzinfo=dt.UTC) + + insert_inputs = BigQueryInsertInputs( + team_id=ateam.pk, + table_id=f"test_insert_activity_table_{ateam.pk}", + dataset_id=bigquery_dataset.dataset_id, + data_interval_start=data_interval_start.isoformat(), + # The extra second is because the upper range select is exclusive and + # we want cutoff to be the last event included. + data_interval_end=(cutoff_data_interval_end + dt.timedelta(seconds=1)).isoformat(), + use_json_type=True, + batch_export_model=batch_export_model, + **bigquery_config, + ) + + await activity_environment.run(insert_into_bigquery_activity, insert_inputs) + + done_ranges = [ + ( + data_interval_start.isoformat(), + cutoff_data_interval_end.isoformat(), + ), + ] + workflow_id = uuid.uuid4() + + fake_info = activity.Info( + activity_id="insert-into-bigquery-activity", + activity_type="unknown", + current_attempt_scheduled_time=dt.datetime.now(dt.UTC), + workflow_id=str(workflow_id), + workflow_type="bigquery-export", + workflow_run_id=str(uuid.uuid4()), + attempt=1, + heartbeat_timeout=dt.timedelta(seconds=1), + heartbeat_details=[done_ranges], + is_local=False, + schedule_to_close_timeout=dt.timedelta(seconds=10), + scheduled_time=dt.datetime.now(dt.UTC), + start_to_close_timeout=dt.timedelta(seconds=20), + started_time=dt.datetime.now(dt.UTC), + task_queue="test", + task_token=b"test", + workflow_namespace="default", + ) + + activity_environment.info = fake_info + + insert_inputs = BigQueryInsertInputs( + team_id=ateam.pk, + table_id=f"test_insert_activity_table_{ateam.pk}", + dataset_id=bigquery_dataset.dataset_id, + data_interval_start=data_interval_start.isoformat(), + data_interval_end=data_interval_end.isoformat(), + use_json_type=True, + batch_export_model=batch_export_model, + **bigquery_config, + ) + + await activity_environment.run(insert_into_bigquery_activity, insert_inputs) + + await assert_clickhouse_records_in_bigquery( + bigquery_client=bigquery_client, + clickhouse_client=clickhouse_client, + table_id=f"test_insert_activity_table_{ateam.pk}", + dataset_id=bigquery_dataset.dataset_id, + team_id=ateam.pk, + date_ranges=[(data_interval_start, data_interval_end)], + include_events=None, + batch_export_model=batch_export_model, + use_json_type=True, + min_ingested_timestamp=now, + sort_key="event", + expect_duplicates=True, + ) + + @pytest.fixture def table_id(ateam, interval): return f"test_workflow_table_{ateam.pk}_{interval}" @@ -532,7 +773,7 @@ async def test_bigquery_export_workflow( id=workflow_id, task_queue=BATCH_EXPORTS_TASK_QUEUE, retry_policy=RetryPolicy(maximum_attempts=1), - execution_timeout=dt.timedelta(seconds=10), + execution_timeout=dt.timedelta(seconds=30), ) runs = await afetch_batch_export_runs(batch_export_id=bigquery_batch_export.id) @@ -552,8 +793,7 @@ async def test_bigquery_export_workflow( table_id=table_id, dataset_id=bigquery_batch_export.destination.config["dataset_id"], team_id=ateam.pk, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, + date_ranges=[(data_interval_start, data_interval_end)], exclude_events=exclude_events, include_events=None, batch_export_model=model, @@ -715,8 +955,7 @@ async def test_bigquery_export_workflow_backfill_earliest_persons( table_id=table_id, dataset_id=bigquery_batch_export.destination.config["dataset_id"], team_id=ateam.pk, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, + date_ranges=[(data_interval_start, data_interval_end)], batch_export_model=model, use_json_type=use_json_type, sort_key="person_id", @@ -759,6 +998,7 @@ async def insert_into_bigquery_activity_mocked(_: BigQueryInsertInputs) -> str: id=workflow_id, task_queue=BATCH_EXPORTS_TASK_QUEUE, retry_policy=RetryPolicy(maximum_attempts=1), + execution_timeout=dt.timedelta(seconds=20), ) runs = await afetch_batch_export_runs(batch_export_id=bigquery_batch_export.id) diff --git a/posthog/temporal/tests/batch_exports/test_heartbeat.py b/posthog/temporal/tests/batch_exports/test_heartbeat.py new file mode 100644 index 0000000000000..d09863befa5c3 --- /dev/null +++ b/posthog/temporal/tests/batch_exports/test_heartbeat.py @@ -0,0 +1,104 @@ +import datetime as dt + +import pytest + +from posthog.temporal.batch_exports.heartbeat import BatchExportRangeHeartbeatDetails + + +@pytest.mark.parametrize( + "initial_done_ranges,done_range,expected_index", + [ + # Case 1: Inserting into an empty initial list. + ([], (dt.datetime.fromtimestamp(5), dt.datetime.fromtimestamp(6)), 0), + # Case 2: Inserting into middle of initial list. + ( + [ + (dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)), + (dt.datetime.fromtimestamp(6), dt.datetime.fromtimestamp(10)), + ], + (dt.datetime.fromtimestamp(5), dt.datetime.fromtimestamp(6)), + 1, + ), + # Case 3: Inserting into beginning of initial list. + ( + [ + (dt.datetime.fromtimestamp(1), dt.datetime.fromtimestamp(5)), + (dt.datetime.fromtimestamp(6), dt.datetime.fromtimestamp(10)), + ], + (dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(1)), + 0, + ), + # Case 4: Inserting into end of initial list. + ( + [(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(10))], + (dt.datetime.fromtimestamp(10), dt.datetime.fromtimestamp(11)), + 1, + ), + # Case 5: Inserting disconnected range into middle of initial list. + ( + [ + (dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(10)), + (dt.datetime.fromtimestamp(15), dt.datetime.fromtimestamp(20)), + ], + (dt.datetime.fromtimestamp(12), dt.datetime.fromtimestamp(13)), + 1, + ), + ], +) +def test_insert_done_range(initial_done_ranges, done_range, expected_index): + """Test `BatchExportRangeHeartbeatDetails` inserts a done range in the expected index. + + We avoid merging ranges to maintain the original index so we can assert it matches + the expected index. + """ + heartbeat_details = BatchExportRangeHeartbeatDetails() + heartbeat_details.done_ranges.extend(initial_done_ranges) + heartbeat_details.insert_done_range(done_range, merge=False) + + assert len(heartbeat_details.done_ranges) == len(initial_done_ranges) + 1 + assert heartbeat_details.done_ranges.index(done_range) == expected_index + + +@pytest.mark.parametrize( + "initial_done_ranges,expected_done_ranges", + [ + # Case 1: Disconnected ranges are not merged. + ( + [ + (dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)), + (dt.datetime.fromtimestamp(6), dt.datetime.fromtimestamp(10)), + ], + [ + (dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)), + (dt.datetime.fromtimestamp(6), dt.datetime.fromtimestamp(10)), + ], + ), + # Case 2: Connected ranges are merged. + ( + [ + (dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)), + (dt.datetime.fromtimestamp(5), dt.datetime.fromtimestamp(10)), + ], + [(dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(10))], + ), + # Case 3: Connected ranges are merged, but disconnected are not. + ( + [ + (dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(5)), + (dt.datetime.fromtimestamp(5), dt.datetime.fromtimestamp(10)), + (dt.datetime.fromtimestamp(11), dt.datetime.fromtimestamp(12)), + ], + [ + (dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(10)), + (dt.datetime.fromtimestamp(11), dt.datetime.fromtimestamp(12)), + ], + ), + ], +) +def test_merge_done_ranges(initial_done_ranges, expected_done_ranges): + """Test `BatchExportRangeHeartbeatDetails` merges done ranges.""" + heartbeat_details = BatchExportRangeHeartbeatDetails() + heartbeat_details.done_ranges.extend(initial_done_ranges) + heartbeat_details.merge_done_ranges() + + assert heartbeat_details.done_ranges == expected_done_ranges diff --git a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py index 2067ae65d7cae..aaf4469435508 100644 --- a/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_redshift_batch_export_workflow.py @@ -3,7 +3,7 @@ import operator import os import warnings -from uuid import uuid4 +import uuid import psycopg import pytest @@ -59,13 +59,13 @@ async def assert_clickhouse_records_in_redshfit( table_name: str, team_id: int, batch_export_model: BatchExportModel | BatchExportSchema | None, - data_interval_start: dt.datetime, - data_interval_end: dt.datetime, + date_ranges: list[tuple[dt.datetime, dt.datetime]], exclude_events: list[str] | None = None, include_events: list[str] | None = None, properties_data_type: str = "varchar", sort_key: str = "event", is_backfill: bool = False, + expected_duplicates_threshold: float = 0.0, ): """Assert expected records are written to a given Redshift table. @@ -89,6 +89,9 @@ async def assert_clickhouse_records_in_redshfit( table_name: Redshift table name. team_id: The ID of the team that we are testing events for. batch_export_schema: Custom schema used in the batch export. + date_ranges: Ranges of records we should expect to have been exported. + expected_duplicates_threshold: Threshold of duplicates we should expect relative to + number of unique events, fail if we exceed it. """ super_columns = ["properties", "set", "set_once", "person_properties"] @@ -132,33 +135,49 @@ async def assert_clickhouse_records_in_redshfit( ] expected_records = [] - async for record_batch in iter_model_records( - client=clickhouse_client, - model=batch_export_model, - team_id=team_id, - interval_start=data_interval_start.isoformat(), - interval_end=data_interval_end.isoformat(), - exclude_events=exclude_events, - include_events=include_events, - destination_default_fields=redshift_default_fields(), - is_backfill=is_backfill, - ): - for record in record_batch.select(schema_column_names).to_pylist(): - expected_record = {} - - for k, v in record.items(): - if k not in schema_column_names or k == "_inserted_at": - # _inserted_at is not exported, only used for tracking progress. - continue - - elif k in super_columns and v is not None: - expected_record[k] = remove_escaped_whitespace_recursive(json.loads(v)) - elif isinstance(v, dt.datetime): - expected_record[k] = v.replace(tzinfo=dt.UTC) - else: - expected_record[k] = v - - expected_records.append(expected_record) + for data_interval_start, data_interval_end in date_ranges: + async for record_batch in iter_model_records( + client=clickhouse_client, + model=batch_export_model, + team_id=team_id, + interval_start=data_interval_start.isoformat(), + interval_end=data_interval_end.isoformat(), + exclude_events=exclude_events, + include_events=include_events, + destination_default_fields=redshift_default_fields(), + is_backfill=is_backfill, + ): + for record in record_batch.select(schema_column_names).to_pylist(): + expected_record = {} + + for k, v in record.items(): + if k not in schema_column_names or k == "_inserted_at": + # _inserted_at is not exported, only used for tracking progress. + continue + + elif k in super_columns and v is not None: + expected_record[k] = remove_escaped_whitespace_recursive(json.loads(v)) + elif isinstance(v, dt.datetime): + expected_record[k] = v.replace(tzinfo=dt.UTC) + else: + expected_record[k] = v + + expected_records.append(expected_record) + + if expected_duplicates_threshold > 0.0: + seen = set() + + def is_record_seen(record) -> bool: + nonlocal seen + if record["uuid"] in seen: + return True + + seen.add(record["uuid"]) + return False + + inserted_records = [record for record in inserted_records if not is_record_seen(record)] + unduplicated_len = len(inserted_records) + assert (unduplicated_len - len(inserted_records)) / len(inserted_records) < expected_duplicates_threshold inserted_column_names = list(inserted_records[0].keys()) expected_column_names = list(expected_records[0].keys()) @@ -171,6 +190,7 @@ async def assert_clickhouse_records_in_redshfit( assert inserted_column_names == expected_column_names assert inserted_records[0] == expected_records[0] assert inserted_records == expected_records + assert len(inserted_records) == len(expected_records) @pytest.fixture @@ -348,8 +368,7 @@ async def test_insert_into_redshift_activity_inserts_data_into_redshift_table( schema_name=redshift_config["schema"], table_name=table_name, team_id=ateam.pk, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, + date_ranges=[(data_interval_start, data_interval_end)], batch_export_model=model, exclude_events=exclude_events, properties_data_type=properties_data_type, @@ -357,6 +376,227 @@ async def test_insert_into_redshift_activity_inserts_data_into_redshift_table( ) +@pytest.mark.parametrize("interval", ["hour"], indirect=True) +@pytest.mark.parametrize( + "done_relative_ranges,expected_relative_ranges", + [ + ( + [(dt.timedelta(minutes=0), dt.timedelta(minutes=15))], + [(dt.timedelta(minutes=15), dt.timedelta(minutes=60))], + ), + ( + [ + (dt.timedelta(minutes=10), dt.timedelta(minutes=15)), + (dt.timedelta(minutes=35), dt.timedelta(minutes=45)), + ], + [ + (dt.timedelta(minutes=0), dt.timedelta(minutes=10)), + (dt.timedelta(minutes=15), dt.timedelta(minutes=35)), + (dt.timedelta(minutes=45), dt.timedelta(minutes=60)), + ], + ), + ( + [ + (dt.timedelta(minutes=45), dt.timedelta(minutes=60)), + ], + [ + (dt.timedelta(minutes=0), dt.timedelta(minutes=45)), + ], + ), + ], +) +async def test_insert_into_bigquery_activity_resumes_from_heartbeat( + clickhouse_client, + activity_environment, + psycopg_connection, + redshift_config, + exclude_events, + generate_test_data, + data_interval_start, + data_interval_end, + ateam, + done_relative_ranges, + expected_relative_ranges, +): + """Test we insert partial data into a BigQuery table when resuming. + + After an activity runs, heartbeats, and crashes, a follow-up activity should + pick-up from where the first one left. This capability is critical to ensure + long-running activities that export a lot of data will eventually finish. + """ + batch_export_model = BatchExportModel(name="events", schema=None) + properties_data_type = "varchar" + + insert_inputs = RedshiftInsertInputs( + team_id=ateam.pk, + table_name=f"test_insert_activity_table_{ateam.pk}", + data_interval_start=data_interval_start.isoformat(), + data_interval_end=data_interval_end.isoformat(), + exclude_events=exclude_events, + batch_export_model=batch_export_model, + properties_data_type=properties_data_type, + **redshift_config, + ) + + done_ranges = [ + ( + (data_interval_start + done_relative_range[0]).isoformat(), + (data_interval_start + done_relative_range[1]).isoformat(), + ) + for done_relative_range in done_relative_ranges + ] + expected_ranges = [ + ( + (data_interval_start + expected_relative_range[0]), + (data_interval_start + expected_relative_range[1]), + ) + for expected_relative_range in expected_relative_ranges + ] + workflow_id = uuid.uuid4() + + fake_info = activity.Info( + activity_id="insert-into-redshift-activity", + activity_type="unknown", + current_attempt_scheduled_time=dt.datetime.now(dt.UTC), + workflow_id=str(workflow_id), + workflow_type="redshift-export", + workflow_run_id=str(uuid.uuid4()), + attempt=1, + heartbeat_timeout=dt.timedelta(seconds=1), + heartbeat_details=[done_ranges], + is_local=False, + schedule_to_close_timeout=dt.timedelta(seconds=10), + scheduled_time=dt.datetime.now(dt.UTC), + start_to_close_timeout=dt.timedelta(seconds=20), + started_time=dt.datetime.now(dt.UTC), + task_queue="test", + task_token=b"test", + workflow_namespace="default", + ) + + activity_environment.info = fake_info + await activity_environment.run(insert_into_redshift_activity, insert_inputs) + + await assert_clickhouse_records_in_redshfit( + redshift_connection=psycopg_connection, + clickhouse_client=clickhouse_client, + schema_name=redshift_config["schema"], + table_name=f"test_insert_activity_table_{ateam.pk}", + team_id=ateam.pk, + date_ranges=expected_ranges, + batch_export_model=batch_export_model, + exclude_events=exclude_events, + properties_data_type=properties_data_type, + sort_key="event", + expected_duplicates_threshold=0.1, + ) + + +async def test_insert_into_redshift_activity_completes_range( + clickhouse_client, + activity_environment, + psycopg_connection, + redshift_config, + exclude_events, + generate_test_data, + data_interval_start, + data_interval_end, + ateam, +): + """Test we complete a full range of data into a Redshift table when resuming. + + We run two activities: + 1. First activity, up to (and including) the cutoff event. + 2. Second activity with a heartbeat detail matching the cutoff event. + + This simulates the batch export resuming from a failed execution. The full range + should be completed (with a duplicate on the cutoff event) after both activities + are done. + """ + batch_export_model = BatchExportModel(name="events", schema=None) + properties_data_type = "varchar" + + events_to_export_created, _ = generate_test_data + events_to_export_created.sort(key=operator.itemgetter("inserted_at")) + + cutoff_event = events_to_export_created[len(events_to_export_created) // 2 : len(events_to_export_created) // 2 + 1] + assert len(cutoff_event) == 1 + cutoff_event = cutoff_event[0] + cutoff_data_interval_end = dt.datetime.fromisoformat(cutoff_event["inserted_at"]).replace(tzinfo=dt.UTC) + + insert_inputs = RedshiftInsertInputs( + team_id=ateam.pk, + table_name=f"test_insert_activity_table_{ateam.pk}", + data_interval_start=data_interval_start.isoformat(), + # The extra second is because the upper range select is exclusive and + # we want cutoff to be the last event included. + data_interval_end=(cutoff_data_interval_end + dt.timedelta(seconds=1)).isoformat(), + exclude_events=exclude_events, + batch_export_model=batch_export_model, + properties_data_type=properties_data_type, + **redshift_config, + ) + + await activity_environment.run(insert_into_redshift_activity, insert_inputs) + + done_ranges = [ + ( + data_interval_start.isoformat(), + cutoff_data_interval_end.isoformat(), + ), + ] + workflow_id = uuid.uuid4() + + fake_info = activity.Info( + activity_id="insert-into-bigquery-activity", + activity_type="unknown", + current_attempt_scheduled_time=dt.datetime.now(dt.UTC), + workflow_id=str(workflow_id), + workflow_type="bigquery-export", + workflow_run_id=str(uuid.uuid4()), + attempt=1, + heartbeat_timeout=dt.timedelta(seconds=1), + heartbeat_details=[done_ranges], + is_local=False, + schedule_to_close_timeout=dt.timedelta(seconds=10), + scheduled_time=dt.datetime.now(dt.UTC), + start_to_close_timeout=dt.timedelta(seconds=20), + started_time=dt.datetime.now(dt.UTC), + task_queue="test", + task_token=b"test", + workflow_namespace="default", + ) + + activity_environment.info = fake_info + + insert_inputs = RedshiftInsertInputs( + team_id=ateam.pk, + table_name=f"test_insert_activity_table_{ateam.pk}", + data_interval_start=data_interval_start.isoformat(), + data_interval_end=data_interval_end.isoformat(), + exclude_events=exclude_events, + batch_export_model=batch_export_model, + properties_data_type=properties_data_type, + **redshift_config, + ) + + await activity_environment.run(insert_into_redshift_activity, insert_inputs) + + await assert_clickhouse_records_in_redshfit( + redshift_connection=psycopg_connection, + clickhouse_client=clickhouse_client, + schema_name=redshift_config["schema"], + table_name=f"test_insert_activity_table_{ateam.pk}", + team_id=ateam.pk, + date_ranges=[(data_interval_start, data_interval_end)], + batch_export_model=batch_export_model, + exclude_events=exclude_events, + properties_data_type=properties_data_type, + sort_key="event", + expected_duplicates_threshold=0.1, + ) + + @pytest.fixture def table_name(ateam, interval): return f"test_workflow_table_{ateam.pk}_{interval}" @@ -421,7 +661,7 @@ async def test_redshift_export_workflow( elif model is not None: batch_export_schema = model - workflow_id = str(uuid4()) + workflow_id = str(uuid.uuid4()) inputs = RedshiftBatchExportInputs( team_id=ateam.pk, batch_export_id=str(redshift_batch_export.id), @@ -470,8 +710,7 @@ async def test_redshift_export_workflow( schema_name=redshift_config["schema"], table_name=table_name, team_id=ateam.pk, - data_interval_start=data_interval_start, - data_interval_end=data_interval_end, + date_ranges=[(data_interval_start, data_interval_end)], batch_export_model=model, exclude_events=exclude_events, sort_key="person_id" if batch_export_model is not None and batch_export_model.name == "persons" else "event", @@ -495,11 +734,13 @@ def test_remove_escaped_whitespace_recursive(value, expected): assert remove_escaped_whitespace_recursive(value) == expected -async def test_redshift_export_workflow_handles_insert_activity_errors(ateam, redshift_batch_export, interval): +async def test_redshift_export_workflow_handles_insert_activity_errors( + event_loop, ateam, redshift_batch_export, interval +): """Test that Redshift Export Workflow can gracefully handle errors when inserting Redshift data.""" data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - workflow_id = str(uuid4()) + workflow_id = str(uuid.uuid4()) inputs = RedshiftBatchExportInputs( team_id=ateam.pk, batch_export_id=str(redshift_batch_export.id), @@ -531,6 +772,7 @@ async def insert_into_redshift_activity_mocked(_: RedshiftInsertInputs) -> str: id=workflow_id, task_queue=settings.TEMPORAL_TASK_QUEUE, retry_policy=RetryPolicy(maximum_attempts=1), + execution_timeout=dt.timedelta(seconds=20), ) runs = await afetch_batch_export_runs(batch_export_id=redshift_batch_export.id) @@ -548,7 +790,7 @@ async def test_redshift_export_workflow_handles_insert_activity_non_retryable_er """Test that Redshift Export Workflow can gracefully handle non-retryable errors when inserting Redshift data.""" data_interval_end = dt.datetime.fromisoformat("2023-04-25T14:30:00.000000+00:00") - workflow_id = str(uuid4()) + workflow_id = str(uuid.uuid4()) inputs = RedshiftBatchExportInputs( team_id=ateam.pk, batch_export_id=str(redshift_batch_export.id), diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index 576697d2e47ec..f13b4132a1bd8 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -28,7 +28,7 @@ ) from posthog.temporal.batch_exports.s3_batch_export import ( FILE_FORMAT_EXTENSIONS, - HeartbeatDetails, + S3HeartbeatDetails, IntermittentUploadPartTimeoutError, S3BatchExportInputs, S3BatchExportWorkflow, @@ -1010,6 +1010,12 @@ async def test_s3_export_workflow_with_minio_bucket_and_custom_key_prefix( ) +class RetryableTestException(Exception): + """An exception to be raised during tests""" + + pass + + async def test_s3_export_workflow_handles_insert_activity_errors(ateam, s3_batch_export, interval): """Test S3BatchExport Workflow can handle errors from executing the insert into S3 activity. @@ -1028,7 +1034,7 @@ async def test_s3_export_workflow_handles_insert_activity_errors(ateam, s3_batch @activity.defn(name="insert_into_s3_activity") async def insert_into_s3_activity_mocked(_: S3InsertInputs) -> str: - raise ValueError("A useful error message") + raise RetryableTestException("A useful error message") async with await WorkflowEnvironment.start_time_skipping() as activity_environment: async with Worker( @@ -1056,7 +1062,7 @@ async def insert_into_s3_activity_mocked(_: S3InsertInputs) -> str: run = runs[0] assert run.status == "FailedRetryable" - assert run.latest_error == "ValueError: A useful error message" + assert run.latest_error == "RetryableTestException: A useful error message" assert run.records_completed is None @@ -1387,14 +1393,14 @@ async def test_insert_into_s3_activity_heartbeats( inserted_at=part_inserted_at, ) - heartbeat_details = [] + heartbeat_details: list[S3HeartbeatDetails] = [] def track_hearbeat_details(*details): """Record heartbeat details received.""" nonlocal heartbeat_details - details = HeartbeatDetails.from_activity_details(details) - heartbeat_details.append(details) + s3_details = S3HeartbeatDetails.from_activity_details(details) + heartbeat_details.append(s3_details) activity_environment.on_heartbeat = track_hearbeat_details @@ -1415,11 +1421,13 @@ def track_hearbeat_details(*details): assert len(heartbeat_details) > 0 - for detail in heartbeat_details: - last_uploaded_part_dt = dt.datetime.fromisoformat(detail.last_uploaded_part_timestamp) - assert last_uploaded_part_dt == data_interval_end - s3_batch_export.interval_time_delta / len( - detail.upload_state.parts - ) + detail = heartbeat_details[-1] + + assert detail.upload_state is not None + assert len(detail.upload_state.parts) == 3 + assert len(detail.done_ranges) == 1 + + assert detail.done_ranges[0] == (data_interval_start, data_interval_end) await assert_clickhouse_records_in_s3( s3_compatible_client=minio_client, diff --git a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py index 8c6d944fd3394..e99ef3f1ca350 100644 --- a/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_snowflake_batch_export_workflow.py @@ -1631,7 +1631,13 @@ def capture_heartbeat_details(*details): ) -@pytest.mark.parametrize("details", [(str(dt.datetime.now()), 1)]) +@pytest.mark.parametrize( + "details", + [ + ([(dt.datetime.now().isoformat(), dt.datetime.now().isoformat())], 1), + ([(dt.datetime.now().isoformat(), dt.datetime.now().isoformat())],), + ], +) def test_snowflake_heartbeat_details_parses_from_tuple(details): class FakeActivity: def info(self): @@ -1642,6 +1648,16 @@ def __init__(self): self.heartbeat_details = details snowflake_details = SnowflakeHeartbeatDetails.from_activity(FakeActivity()) + expected_done_ranges = details[0] + + assert snowflake_details.done_ranges == [ + ( + dt.datetime.fromisoformat(expected_done_ranges[0][0]), + dt.datetime.fromisoformat(expected_done_ranges[0][1]), + ) + ] - assert snowflake_details.last_inserted_at == dt.datetime.fromisoformat(details[0]) - assert snowflake_details.file_no == details[1] + if len(details) >= 2: + assert snowflake_details.file_no == details[1] + else: + assert snowflake_details.file_no == 0 diff --git a/posthog/temporal/tests/batch_exports/test_temporary_file.py b/posthog/temporal/tests/batch_exports/test_temporary_file.py index 900ca5e9d3fed..5fb20241ee195 100644 --- a/posthog/temporal/tests/batch_exports/test_temporary_file.py +++ b/posthog/temporal/tests/batch_exports/test_temporary_file.py @@ -11,7 +11,7 @@ BatchExportTemporaryFile, CSVBatchExportWriter, JSONLBatchExportWriter, - LastInsertedAt, + DateRange, ParquetBatchExportWriter, json_dumps_bytes, ) @@ -209,7 +209,9 @@ def test_batch_export_temporary_file_write_records_to_tsv(records): { "event": pa.array(["test-event-0", "test-event-1", "test-event-2"]), "properties": pa.array(['{"prop_0": 1, "prop_1": 2}', "{}", "null"]), - "_inserted_at": pa.array([0, 1, 2]), + "_inserted_at": pa.array( + [dt.datetime.fromtimestamp(0), dt.datetime.fromtimestamp(1), dt.datetime.fromtimestamp(2)] + ), } ) ] @@ -223,20 +225,20 @@ def test_batch_export_temporary_file_write_records_to_tsv(records): async def test_jsonl_writer_writes_record_batches(record_batch): """Test record batches are written as valid JSONL.""" in_memory_file_obj = io.BytesIO() - inserted_ats_seen: list[LastInsertedAt] = [] + date_ranges_seen: list[DateRange] = [] async def store_in_memory_on_flush( batch_export_file, records_since_last_flush, bytes_since_last_flush, flush_counter, - last_inserted_at, + last_date_range, is_last, error, ): assert writer.records_since_last_flush == record_batch.num_rows in_memory_file_obj.write(batch_export_file.read()) - inserted_ats_seen.append(last_inserted_at) + date_ranges_seen.append(last_date_range) writer = JSONLBatchExportWriter(max_bytes=1, flush_callable=store_in_memory_on_flush) @@ -257,7 +259,9 @@ async def store_in_memory_on_flush( assert "_inserted_at" not in written_jsonl assert written_jsonl == {k: v for k, v in expected_jsonl.items() if k != "_inserted_at"} - assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()] + assert date_ranges_seen == [ + (record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py()) + ] @pytest.mark.parametrize( @@ -268,19 +272,19 @@ async def store_in_memory_on_flush( async def test_csv_writer_writes_record_batches(record_batch): """Test record batches are written as valid CSV.""" in_memory_file_obj = io.StringIO() - inserted_ats_seen = [] + date_ranges_seen: list[DateRange] = [] async def store_in_memory_on_flush( batch_export_file, records_since_last_flush, bytes_since_last_flush, flush_counter, - last_inserted_at, + last_date_range, is_last, error, ): in_memory_file_obj.write(batch_export_file.read().decode("utf-8")) - inserted_ats_seen.append(last_inserted_at) + date_ranges_seen.append(last_date_range) schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] writer = CSVBatchExportWriter(max_bytes=1, field_names=schema_columns, flush_callable=store_in_memory_on_flush) @@ -304,7 +308,9 @@ async def store_in_memory_on_flush( assert "_inserted_at" not in written_csv_row assert written_csv_row == list({k: v for k, v in expected_dict.items() if k != "_inserted_at"}.values()) - assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()] + assert date_ranges_seen == [ + (record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py()) + ] @pytest.mark.parametrize( @@ -315,19 +321,19 @@ async def store_in_memory_on_flush( async def test_parquet_writer_writes_record_batches(record_batch): """Test record batches are written as valid Parquet.""" in_memory_file_obj = io.BytesIO() - inserted_ats_seen = [] + date_ranges_seen: list[DateRange] = [] async def store_in_memory_on_flush( batch_export_file, records_since_last_flush, bytes_since_last_flush, flush_counter, - last_inserted_at, + last_date_range, is_last, error, ): in_memory_file_obj.write(batch_export_file.read()) - inserted_ats_seen.append(last_inserted_at) + date_ranges_seen.append(last_date_range) schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] @@ -353,9 +359,9 @@ async def store_in_memory_on_flush( # NOTE: Parquet gets flushed twice due to the extra flush at the end for footer bytes, so our mock function # will see this value twice. - assert inserted_ats_seen == [ - record_batch.column("_inserted_at")[-1].as_py(), - record_batch.column("_inserted_at")[-1].as_py(), + assert date_ranges_seen == [ + (record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py()), + (record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py()), ] @@ -412,7 +418,7 @@ async def track_flushes(*args, **kwargs): assert writer.bytes_since_last_flush == writer.batch_export_file.bytes_since_last_reset assert writer.records_since_last_flush == record_batch.num_rows - await writer.flush(dt.datetime.now()) + await writer.flush() assert flush_counter == 1 assert writer.batch_export_file.tell() == 0 @@ -427,7 +433,7 @@ async def track_flushes(*args, **kwargs): async def test_jsonl_writer_deals_with_web_vitals(): """Test old $web_vitals record batches are written as valid JSONL.""" in_memory_file_obj = io.BytesIO() - inserted_ats_seen: list[LastInsertedAt] = [] + date_ranges_seen: list[DateRange] = [] record_batch = pa.RecordBatch.from_pydict( { @@ -442,7 +448,7 @@ async def test_jsonl_writer_deals_with_web_vitals(): } ] ), - "_inserted_at": pa.array([0]), + "_inserted_at": pa.array([dt.datetime.fromtimestamp(0)]), } ) @@ -451,13 +457,13 @@ async def store_in_memory_on_flush( records_since_last_flush, bytes_since_last_flush, flush_counter, - last_inserted_at, + last_date_range, is_last, error, ): assert writer.records_since_last_flush == record_batch.num_rows in_memory_file_obj.write(batch_export_file.read()) - inserted_ats_seen.append(last_inserted_at) + date_ranges_seen.append(last_date_range) writer = JSONLBatchExportWriter(max_bytes=1, flush_callable=store_in_memory_on_flush) @@ -479,20 +485,22 @@ async def store_in_memory_on_flush( del expected_jsonl["properties"]["$web_vitals_INP_event"]["attribution"]["interactionTargetElement"] assert written_jsonl == {k: v for k, v in expected_jsonl.items() if k != "_inserted_at"} - assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()] + assert date_ranges_seen == [ + (record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py()) + ] @pytest.mark.asyncio async def test_jsonl_writer_deals_with_nested_user_events(): """Test very nested user event record batches are written as valid JSONL.""" in_memory_file_obj = io.BytesIO() - inserted_ats_seen: list[LastInsertedAt] = [] + date_ranges_seen: list[DateRange] = [] record_batch = pa.RecordBatch.from_pydict( { "event": pa.array(["my_event"]), "properties": pa.array([{"we_have_to_go_deeper": json.loads("[" * 256 + "]" * 256)}]), - "_inserted_at": pa.array([0]), + "_inserted_at": pa.array([dt.datetime.fromtimestamp(0)]), } ) @@ -501,13 +509,13 @@ async def store_in_memory_on_flush( records_since_last_flush, bytes_since_last_flush, flush_counter, - last_inserted_at, + last_date_range, is_last, error, ): assert writer.records_since_last_flush == record_batch.num_rows in_memory_file_obj.write(batch_export_file.read()) - inserted_ats_seen.append(last_inserted_at) + date_ranges_seen.append(last_date_range) writer = JSONLBatchExportWriter(max_bytes=1, flush_callable=store_in_memory_on_flush) @@ -525,4 +533,6 @@ async def store_in_memory_on_flush( assert "_inserted_at" not in written_jsonl assert written_jsonl == {k: v for k, v in expected_jsonl.items() if k != "_inserted_at"} - assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()] + assert date_ranges_seen == [ + (record_batch.column("_inserted_at")[0].as_py(), record_batch.column("_inserted_at")[-1].as_py()) + ] diff --git a/posthog/temporal/tests/batch_exports/utils.py b/posthog/temporal/tests/batch_exports/utils.py index 2c48c26248dc9..8ea929d38eca9 100644 --- a/posthog/temporal/tests/batch_exports/utils.py +++ b/posthog/temporal/tests/batch_exports/utils.py @@ -16,7 +16,6 @@ async def mocked_start_batch_export_run(inputs: StartBatchExportRunInputs) -> st data_interval_start=inputs.data_interval_start, data_interval_end=inputs.data_interval_end, status=BatchExportRun.Status.STARTING, - records_total_count=1, ) return str(run.id) diff --git a/posthog/temporal/tests/utils/events.py b/posthog/temporal/tests/utils/events.py index 399b00ec7347c..83a76f26bde65 100644 --- a/posthog/temporal/tests/utils/events.py +++ b/posthog/temporal/tests/utils/events.py @@ -1,6 +1,7 @@ """Test utilities that deal with test event generation.""" import datetime as dt +import itertools import json import random import typing @@ -48,41 +49,52 @@ def generate_test_events( distinct_ids: list[str] | None = None, ): """Generate a list of events for testing.""" - _timestamp = random.choice(possible_datetimes) - - if inserted_at == "_timestamp": - inserted_at_value = _timestamp.strftime("%Y-%m-%d %H:%M:%S.%f") - elif inserted_at == "random": - inserted_at_value = random.choice(possible_datetimes).strftime("%Y-%m-%d %H:%M:%S.%f") - elif inserted_at is None: - inserted_at_value = None + datetime_sample = random.sample(possible_datetimes, len(possible_datetimes)) + datetime_cycle = itertools.cycle(datetime_sample) + _timestamp = next(datetime_cycle) + + if distinct_ids: + distinct_id_sample = random.sample(distinct_ids, len(distinct_ids)) + distinct_id_cycle = itertools.cycle(distinct_id_sample) else: - if not isinstance(inserted_at, dt.datetime): - raise ValueError(f"Unsupported value for inserted_at: '{inserted_at}'") - inserted_at_value = inserted_at.strftime("%Y-%m-%d %H:%M:%S.%f") - - events: list[EventValues] = [ - { - "_timestamp": _timestamp.strftime("%Y-%m-%d %H:%M:%S"), - "created_at": random.choice(possible_datetimes).strftime("%Y-%m-%d %H:%M:%S.%f"), - "distinct_id": random.choice(distinct_ids) if distinct_ids else str(uuid.uuid4()), - "elements": json.dumps("css selectors;"), - "elements_chain": "css selectors;", - "event": event_name.format(i=i), - "inserted_at": inserted_at_value, - "person_id": str(uuid.uuid4()), - "person_properties": person_properties, - "properties": properties, - "team_id": team_id, - "timestamp": random.choice(possible_datetimes).strftime("%Y-%m-%d %H:%M:%S.%f"), - "uuid": str(uuid.uuid4()), - "ip": ip, - "site_url": site_url, - "set": set_field, - "set_once": set_once, - } - for i in range(start, count + start) - ] + distinct_id_cycle = None + + def compute_inserted_at(): + if inserted_at == "_timestamp": + inserted_at_value = _timestamp.strftime("%Y-%m-%d %H:%M:%S.%f") + elif inserted_at == "random": + inserted_at_value = next(datetime_cycle).strftime("%Y-%m-%d %H:%M:%S.%f") + elif inserted_at is None: + inserted_at_value = None + else: + if not isinstance(inserted_at, dt.datetime): + raise ValueError(f"Unsupported value for inserted_at: '{inserted_at}'") + inserted_at_value = inserted_at.strftime("%Y-%m-%d %H:%M:%S.%f") + return inserted_at_value + + events: list[EventValues] = [] + for i in range(start, count + start): + events.append( + { + "_timestamp": _timestamp.strftime("%Y-%m-%d %H:%M:%S"), + "created_at": next(datetime_cycle).strftime("%Y-%m-%d %H:%M:%S.%f"), + "distinct_id": next(distinct_id_cycle) if distinct_id_cycle else str(uuid.uuid4()), + "elements": json.dumps("css selectors;"), + "elements_chain": "css selectors;", + "event": event_name.format(i=i), + "inserted_at": compute_inserted_at(), + "person_id": str(uuid.uuid4()), + "person_properties": person_properties, + "properties": properties, + "team_id": team_id, + "timestamp": next(datetime_cycle).strftime("%Y-%m-%d %H:%M:%S.%f"), + "uuid": str(uuid.uuid4()), + "ip": ip, + "site_url": site_url, + "set": set_field, + "set_once": set_once, + } + ) return events @@ -140,7 +152,7 @@ async def generate_test_events_in_clickhouse( event_name: str = "test-{i}", properties: dict | None = None, person_properties: dict | None = None, - inserted_at: str | dt.datetime | None = "_timestamp", + inserted_at: str | dt.datetime | None = "random", distinct_ids: list[str] | None = None, duplicate: bool = False, batch_size: int = 10000,