From abbd27f5fc48668cc973c1ccf32283996599039c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Mon, 18 Mar 2024 14:02:41 +0100 Subject: [PATCH 01/14] refactor: Support for multiple file formats in batch exports --- posthog/batch_exports/service.py | 1 + .../temporal/batch_exports/batch_exports.py | 164 +++++++++++++++++ .../temporal/batch_exports/s3_batch_export.py | 172 ++++++++++++++---- .../test_s3_batch_export_workflow.py | 114 ++++++++++-- 4 files changed, 397 insertions(+), 54 deletions(-) diff --git a/posthog/batch_exports/service.py b/posthog/batch_exports/service.py index c26be9a77ed1a..11859e9b4ff82 100644 --- a/posthog/batch_exports/service.py +++ b/posthog/batch_exports/service.py @@ -90,6 +90,7 @@ class S3BatchExportInputs: kms_key_id: str | None = None batch_export_schema: BatchExportSchema | None = None endpoint_url: str | None = None + file_format: str = "JSONLines" @dataclass diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index c776e1f245ef3..d2e42f3078efd 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -11,6 +11,7 @@ import brotli import orjson import pyarrow as pa +import pyarrow.parquet as pq from asgiref.sync import sync_to_async from django.conf import settings from temporalio import activity, exceptions, workflow @@ -482,6 +483,169 @@ def reset(self): self.records_since_last_reset = 0 +FlushCallable = collections.abc.Callable[ + [BatchExportTemporaryFile, int, int, dt.datetime, bool], collections.abc.Awaitable[None] +] + + +class BatchExportWriter(typing.Protocol): + batch_export_file: BatchExportTemporaryFile + flush_callable: FlushCallable + records_total: int = 0 + records_since_last_flush: int = 0 + max_bytes: int = 0 + last_inserted_at: dt.datetime | None = None + + async def __aenter__(self): + """Context-manager protocol enter method.""" + self.batch_export_file.__enter__() + return self + + async def __aexit__(self, exc, value, tb): + """Context-manager protocol exit method.""" + if self.last_inserted_at is not None and self.records_since_last_flush > 0: + await self.flush(self.last_inserted_at, is_last=True) + return self.batch_export_file.__exit__(exc, value, tb) + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + ... 
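For reference, a `FlushCallable` is any coroutine taking the temporary file, the record and byte counts accumulated since the last flush, the latest `_inserted_at`, and a flag marking the final flush. A minimal sketch of a conforming callback, with a purely illustrative logging body (the real S3 callback, `flush_to_s3`, appears later in this series):

```python
import datetime as dt

from posthog.temporal.batch_exports.batch_exports import BatchExportTemporaryFile


async def flush_to_log(
    batch_export_file: BatchExportTemporaryFile,
    records_since_last_flush: int,
    bytes_since_last_flush: int,
    last_inserted_at: dt.datetime,
    is_last: bool,
) -> None:
    # A real callback uploads the accumulated file contents (see `flush_to_s3` in the
    # S3 activity later in this series); this one only reports progress.
    print(
        f"Flushing {records_since_last_flush} records ({bytes_since_last_flush} bytes) "
        f"up to {last_inserted_at}, last={is_last}"
    )
```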
+ + async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: + last_inserted_at = record_batch.column("_inserted_at")[0].as_py() + + column_names = record_batch.column_names + column_names.pop(column_names.index("_inserted_at")) + + self._write_record_batch(record_batch.select(column_names)) + + self.last_inserted_at = last_inserted_at + + if self.bytes_since_last_flush >= self.max_bytes: + await self.flush(last_inserted_at) + + async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: + await self.flush_callable( + self.batch_export_file, + self.records_since_last_flush, + self.bytes_since_last_flush, + last_inserted_at, + is_last, + ) + self.batch_export_file.reset() + self.records_since_last_flush = 0 + + @property + def bytes_since_last_flush(self) -> int: + return self.batch_export_file.bytes_since_last_reset + + +class JSONLBatchExportWriter(BatchExportWriter): + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + compression: None | str = None, + default: typing.Callable = str, + ): + self.batch_export_file = BatchExportTemporaryFile(compression=compression) + self.flush_callable = flush_callable + + self.default = default + + self.max_bytes = max_bytes + self.last_inserted_at = None + + def write(self, content: bytes) -> int: + n = self.batch_export_file.write(orjson.dumps(content, default=str) + b"\n") + return n + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as JSONL.""" + for record in record_batch.to_pylist(): + self.write(record) + + self.records_total += 1 + self.records_since_last_flush += 1 + + +class CSVBatchExportWriter(csv.DictWriter, BatchExportWriter): + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + field_names: collections.abc.Sequence[str], + extras_action: typing.Literal["raise", "ignore"] = "ignore", + delimiter: str = ",", + quote_char: str = '"', + escape_char: str | None = "\\", + line_terminator: str = "\n", + quoting=csv.QUOTE_NONE, + compression: str | None = None, + ): + self.batch_export_file = BatchExportTemporaryFile(compression=compression) + self.flush_callable = flush_callable + + super().__init__( + self.batch_export_file, + fieldnames=field_names, + extrasaction=extras_action, + delimiter=delimiter, + quotechar=quote_char, + escapechar=escape_char, + quoting=quoting, + lineterminator=line_terminator, + ) + + self.max_bytes = max_bytes + self.last_inserted_at = None + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as JSONL.""" + self.writerows(record_batch.to_pylist()) + self.records_total += record_batch.num_rows + self.records_since_last_flush += record_batch.num_rows + + +class ParquetBatchExportWriter(pq.ParquetWriter, BatchExportWriter): + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + schema: pa.Schema, + version: str = "2.6", + use_dictionary: bool = True, + compression: str | None = "snappy", + compression_level: int | None = None, + ): + self.batch_export_file = BatchExportTemporaryFile(compression=None) # Handle compression in ParquetWriter + self.flush_callable = flush_callable + + super().__init__( + self.batch_export_file, + schema=schema, + version=version, + use_dictionary=use_dictionary, + compression=compression, + compression_level=compression_level, + ) + + self.max_bytes = max_bytes + self.last_inserted_at = None + + async def __aexit__(self, exc, value, tb): + """Close underlying 
Parquet writer to include footer bytes before flushing last.""" + self.writer.close() + self.is_open = False + + await super().__aexit__(exc, value, tb) + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as JSONL.""" + self.write_batch(record_batch.select(self.schema.names)) + self.records_total += record_batch.num_rows + self.records_since_last_flush += record_batch.num_rows + + @dataclasses.dataclass class CreateBatchExportRunInputs: """Inputs to the create_export_run activity. diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 4d99cbeffd7c3..13e1a60bba999 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -1,4 +1,5 @@ import asyncio +import collections.abc import contextlib import datetime as dt import io @@ -8,6 +9,8 @@ from dataclasses import dataclass import aioboto3 +import orjson +import pyarrow as pa from django.conf import settings from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -18,6 +21,8 @@ from posthog.temporal.batch_exports.batch_exports import ( BatchExportTemporaryFile, CreateBatchExportRunInputs, + JSONLBatchExportWriter, + ParquetBatchExportWriter, UpdateBatchExportRunStatusInputs, create_export_run, default_fields, @@ -30,6 +35,7 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -50,19 +56,31 @@ def get_allowed_template_variables(inputs) -> dict[str, str]: } +FILE_FORMAT_EXTENSIONS = { + "Parquet": "parquet", + "JSONLines": "jsonl", +} + +COMPRESSION_EXTENSIONS = { + "gzip": "gz", + "snappy": "sz", + "brotli": "br", + "ztsd": "zst", + "lz4": "lz4", +} + + def get_s3_key(inputs) -> str: """Return an S3 key given S3InsertInputs.""" template_variables = get_allowed_template_variables(inputs) key_prefix = inputs.prefix.format(**template_variables) + file_extension = FILE_FORMAT_EXTENSIONS[inputs.file_format] base_file_name = f"{inputs.data_interval_start}-{inputs.data_interval_end}" - match inputs.compression: - case "gzip": - file_name = base_file_name + ".jsonl.gz" - case "brotli": - file_name = base_file_name + ".jsonl.br" - case _: - file_name = base_file_name + ".jsonl" + if inputs.compression is not None: + file_name = base_file_name + f".{file_extension}.{COMPRESSION_EXTENSIONS[inputs.compression]}" + else: + file_name = base_file_name + f".{file_extension}" key = posixpath.join(key_prefix, file_name) @@ -311,6 +329,8 @@ class S3InsertInputs: kms_key_id: str | None = None batch_export_schema: BatchExportSchema | None = None endpoint_url: str | None = None + # TODO: In Python 3.11, this could be a enum.StrEnum. 
+ file_format: str = "JSONLines" async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]: @@ -466,50 +486,125 @@ async def worker_shutdown_handler(): asyncio.create_task(worker_shutdown_handler()) - record = None - async with s3_upload as s3_upload: - with BatchExportTemporaryFile(compression=inputs.compression) as local_results_file: + + async def flush_to_s3( + local_results_file, + records_since_last_flush: int, + bytes_since_last_flush: int, + last_inserted_at: dt.datetime, + last: bool, + ): + nonlocal last_uploaded_part_timestamp + + logger.debug( + "Uploading %s part %s containing %s records with size %s bytes", + "last " if last else "", + s3_upload.part_number + 1, + records_since_last_flush, + bytes_since_last_flush, + ) + + await s3_upload.upload_part(local_results_file) + rows_exported.add(records_since_last_flush) + bytes_exported.add(bytes_since_last_flush) + + last_uploaded_part_timestamp = str(last_inserted_at) + activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state()) + + if inputs.file_format == "Parquet": + first_record_batch, record_iterator = peek_first_and_rewind(record_iterator) + first_record_batch = cast_record_batch_json_columns(first_record_batch) + column_names = first_record_batch.column_names + column_names.pop(column_names.index("_inserted_at")) + + schema = pa.schema( + # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other + # record batches have them as nullable. + # Until we figure it out, we set all fields to nullable. There are some fields we know + # are not nullable, but I'm opting for the more flexible option until we out why schemas differ + # between batches. + [field.with_nullable(True) for field in first_record_batch.select(column_names).schema] + ) + + writer = ParquetBatchExportWriter( + max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, + flush_callable=flush_to_s3, + compression=inputs.compression, + schema=schema, + ) + else: + writer = JSONLBatchExportWriter( + max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, + flush_callable=flush_to_s3, + compression=inputs.compression, + ) + + async with writer: rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() - async def flush_to_s3(last_uploaded_part_timestamp: str, last=False): - logger.debug( - "Uploading %s part %s containing %s records with size %s bytes", - "last " if last else "", - s3_upload.part_number + 1, - local_results_file.records_since_last_reset, - local_results_file.bytes_since_last_reset, - ) + for record_batch in record_iterator: + record_batch = cast_record_batch_json_columns(record_batch) - await s3_upload.upload_part(local_results_file) - rows_exported.add(local_results_file.records_since_last_reset) - bytes_exported.add(local_results_file.bytes_since_last_reset) + await writer.write_record_batch(record_batch) - activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state()) + await s3_upload.complete() - for record_batch in record_iterator: - for record in record_batch.to_pylist(): - for json_column in ("properties", "person_properties", "set", "set_once"): - if (json_str := record.get(json_column, None)) is not None: - record[json_column] = json.loads(json_str) + return local_results_file.records_total - inserted_at = record.pop("_inserted_at") - local_results_file.write_records_to_jsonl([record]) +def cast_record_batch_json_columns( + record_batch: pa.RecordBatch, + json_columns: collections.abc.Sequence = 
("properties", "person_properties", "set", "set_once"), +) -> pa.RecordBatch: + """Cast json_columns in record_batch to JsonType. - if local_results_file.tell() > settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES: - last_uploaded_part_timestamp = str(inserted_at) - await flush_to_s3(last_uploaded_part_timestamp) - local_results_file.reset() + We return a new RecordBatch with any json_columns replaced by fields casted to JsonType. + Casting is not copying the underlying array buffers, so memory usage does not increase when creating + the new array or the new record batch. + """ + column_names = set(record_batch.column_names) + intersection = column_names & set(json_columns) + + casted_arrays = [] + for array in record_batch.select(intersection): + if pa.types.is_string(array.type): + casted_array = array.cast(JsonType()) + casted_arrays.append(casted_array) + + remaining_column_names = list(column_names - intersection) + return pa.RecordBatch.from_arrays( + record_batch.select(remaining_column_names).columns + casted_arrays, + names=remaining_column_names + list(intersection), + ) - if local_results_file.tell() > 0 and record is not None: - last_uploaded_part_timestamp = str(inserted_at) - await flush_to_s3(last_uploaded_part_timestamp, last=True) - await s3_upload.complete() +class JsonScalar(pa.ExtensionScalar): + """Represents a JSON binary string.""" - return local_results_file.records_total + def as_py(self) -> dict | None: + if self.value: + return orjson.loads(self.value.as_py()) + else: + return None + + +class JsonType(pa.ExtensionType): + """Type for JSON binary strings.""" + + def __init__(self): + super().__init__(pa.binary(), "json") + + def __arrow_ext_serialize__(self): + return b"" + + @classmethod + def __arrow_ext_deserialize__(self, storage_type, serialized): + return JsonType() + + def __arrow_ext_scalar_class__(self): + return JsonScalar @workflow.defn(name="s3-export") @@ -572,6 +667,7 @@ async def run(self, inputs: S3BatchExportInputs): encryption=inputs.encryption, kms_key_id=inputs.kms_key_id, batch_export_schema=inputs.batch_export_schema, + file_format=inputs.file_format, ) await execute_batch_export_insert_activity( diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index e04e345d11245..3b9ff1dd5c7ac 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -10,10 +10,12 @@ import aioboto3 import botocore.exceptions import brotli +import pyarrow.parquet as pq import pytest import pytest_asyncio from django.conf import settings from django.test import override_settings +from pyarrow import fs from temporalio import activity from temporalio.client import WorkflowFailureError from temporalio.common import RetryPolicy @@ -138,6 +140,42 @@ async def minio_client(bucket_name): await minio_client.delete_bucket(Bucket=bucket_name) +async def read_parquet_from_s3(bucket_name: str, key: str) -> list: + async with aioboto3.Session().client("sts") as sts: + try: + await sts.get_caller_identity() + except botocore.exceptions.NoCredentialsError: + s3 = fs.S3FileSystem( + access_key="object_storage_root_user", + secret_key="object_storage_root_password", + endpoint_override=settings.OBJECT_STORAGE_ENDPOINT, + ) + + else: + s3 = fs.S3FileSystem() + + table = pq.read_table(f"{bucket_name}/{key}", filesystem=s3) + + parquet_data = [] + for batch in table.to_batches(): + 
parquet_data.extend(batch.to_pylist()) + + return parquet_data + + +def read_s3_data_as_json(data: bytes, compression: str | None) -> list: + match compression: + case "gzip": + data = gzip.decompress(data) + case "brotli": + data = brotli.decompress(data) + case _: + pass + + json_data = [json.loads(line) for line in data.decode("utf-8").split("\n") if line] + return json_data + + async def assert_clickhouse_records_in_s3( s3_compatible_client, clickhouse_client: ClickHouseClient, @@ -150,6 +188,7 @@ async def assert_clickhouse_records_in_s3( include_events: list[str] | None = None, batch_export_schema: BatchExportSchema | None = None, compression: str | None = None, + file_format: str = "JSONLines", ): """Assert ClickHouse records are written to JSON in key_prefix in S3 bucket_name. @@ -175,20 +214,16 @@ async def assert_clickhouse_records_in_s3( # Get the object. key = objects["Contents"][0].get("Key") assert key - s3_object = await s3_compatible_client.get_object(Bucket=bucket_name, Key=key) - data = await s3_object["Body"].read() - # Check that the data is correct. - match compression: - case "gzip": - data = gzip.decompress(data) - case "brotli": - data = brotli.decompress(data) - case _: - pass + if file_format == "Parquet": + s3_data = await read_parquet_from_s3(bucket_name, key) - json_data = [json.loads(line) for line in data.decode("utf-8").split("\n") if line] - # Pull out the fields we inserted only + elif file_format == "JSONLines": + s3_object = await s3_compatible_client.get_object(Bucket=bucket_name, Key=key) + data = await s3_object["Body"].read() + s3_data = read_s3_data_as_json(data, compression) + else: + raise ValueError(f"Unsupported file format: {file_format}") if batch_export_schema is not None: schema_column_names = [field["alias"] for field in batch_export_schema["fields"]] @@ -225,9 +260,9 @@ async def assert_clickhouse_records_in_s3( expected_records.append(expected_record) - assert len(json_data) == len(expected_records) - assert json_data[0] == expected_records[0] - assert json_data == expected_records + assert len(s3_data) == len(expected_records) + assert s3_data[0] == expected_records[0] + assert s3_data == expected_records TEST_S3_SCHEMAS: list[BatchExportSchema | None] = [ @@ -255,6 +290,7 @@ async def assert_clickhouse_records_in_s3( @pytest.mark.parametrize("compression", [None, "gzip", "brotli"], indirect=True) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) +@pytest.mark.parametrize("file_format", ["JSONLines", "Parquet"]) async def test_insert_into_s3_activity_puts_data_into_s3( clickhouse_client, bucket_name, @@ -262,6 +298,7 @@ async def test_insert_into_s3_activity_puts_data_into_s3( activity_environment, compression, exclude_events, + file_format, batch_export_schema: BatchExportSchema | None, ): """Test that the insert_into_s3_activity function ends up with data into S3. 
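A brief aside on the `JsonType` extension type added to the S3 module above: it wraps JSON-encoded binary strings so that `as_py()` returns Python dictionaries. A minimal round-trip sketch, assuming a pyarrow version that honours `__arrow_ext_scalar_class__`:

```python
import orjson
import pyarrow as pa

from posthog.temporal.batch_exports.s3_batch_export import JsonType

# Wrap a JSON-encoded binary string in the extension type and read it back as a dict.
storage = pa.array([orjson.dumps({"$browser": "Chrome"})], type=pa.binary())
json_array = pa.ExtensionArray.from_storage(JsonType(), storage)

assert json_array[0].as_py() == {"$browser": "Chrome"}
```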
@@ -339,6 +376,7 @@ async def test_insert_into_s3_activity_puts_data_into_s3( compression=compression, exclude_events=exclude_events, batch_export_schema=batch_export_schema, + file_format=file_format, ) with override_settings( @@ -358,6 +396,7 @@ async def test_insert_into_s3_activity_puts_data_into_s3( exclude_events=exclude_events, include_events=None, compression=compression, + file_format=file_format, ) @@ -1206,6 +1245,49 @@ async def never_finish_activity(_: S3InsertInputs) -> str: ), "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br", ), + ( + S3InsertInputs( + prefix="/nested/prefix/", + data_interval_start="2023-01-01 00:00:00", + data_interval_end="2023-01-01 01:00:00", + file_format="Parquet", + compression="snappy", + **base_inputs, # type: ignore + ), + "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.parquet.sz", + ), + ( + S3InsertInputs( + prefix="/nested/prefix/", + data_interval_start="2023-01-01 00:00:00", + data_interval_end="2023-01-01 01:00:00", + file_format="Parquet", + **base_inputs, # type: ignore + ), + "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.parquet", + ), + ( + S3InsertInputs( + prefix="/nested/prefix/", + data_interval_start="2023-01-01 00:00:00", + data_interval_end="2023-01-01 01:00:00", + compression="gzip", + file_format="Parquet", + **base_inputs, # type: ignore + ), + "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.parquet.gz", + ), + ( + S3InsertInputs( + prefix="/nested/prefix/", + data_interval_start="2023-01-01 00:00:00", + data_interval_end="2023-01-01 01:00:00", + compression="brotli", + file_format="Parquet", + **base_inputs, # type: ignore + ), + "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.parquet.br", + ), ], ) def test_get_s3_key(inputs, expected): @@ -1271,7 +1353,7 @@ def assert_heartbeat_details(*details): endpoint_url=settings.OBJECT_STORAGE_ENDPOINT, ) - with override_settings(BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2): + with override_settings(BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=1, CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT=1): await activity_environment.run(insert_into_s3_activity, insert_inputs) # This checks that the assert_heartbeat_details function was actually called. From 13e7956001291797c5c5268359133f41c51a3d74 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Tue, 19 Mar 2024 10:32:02 +0100 Subject: [PATCH 02/14] refactor: Prefer composition over inheritance --- .../temporal/batch_exports/batch_exports.py | 127 ++++++++++++------ 1 file changed, 89 insertions(+), 38 deletions(-) diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index d2e42f3078efd..b75b0494d9209 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -1,3 +1,4 @@ +import abc import collections.abc import csv import dataclasses @@ -488,13 +489,52 @@ def reset(self): ] -class BatchExportWriter(typing.Protocol): - batch_export_file: BatchExportTemporaryFile - flush_callable: FlushCallable - records_total: int = 0 - records_since_last_flush: int = 0 - max_bytes: int = 0 - last_inserted_at: dt.datetime | None = None +class BatchExportWriter(abc.ABC): + """A temporary file writer to be used by batch export workflows. + + Subclasses should define `_write_record_batch` with the particular intricacies + of the format they are writing as. + + Actual writing calls are passed to the underlying `batch_export_file`. 
+ + Attributes: + batch_export_file: The temporary file we are writing to. + bytes_flush_threshold: Flush the temporary file with the provided `flush_callable` + upon reaching or surpassing this threshold. Keep in mind we write on a RecordBatch + per RecordBatch basis, which means the threshold will be surpassed by at most the + size of a RecordBatch before a flush occurs. + flush_callable: A callback to flush the temporary file when `bytes_flush_treshold` is reached. + The temporary file will be reset after calling `flush_callable`. + records_total: The total number of records (not RecordBatches!) written. + records_since_last_flush: The number of records written since last flush. + last_inserted_at: Latest `_inserted_at` written. This attribute leaks some implementation + details, as we are making two assumptions about the RecordBatches being written: + * We assume RecordBatches are sorted on `_inserted_at`, which currently happens with + an `ORDER BY` clause. + * We assume `_inserted_at` is present, as it's added to all batch export queries. + """ + + def __init__( + self, + batch_export_file: BatchExportTemporaryFile, + flush_callable: FlushCallable, + max_bytes: int, + ): + self.batch_export_file = batch_export_file + self.max_bytes = max_bytes + self.flush_callable = flush_callable + self.records_total = 0 + self.records_since_last_flush = 0 + self.last_inserted_at: dt.datetime | None = None + + @abc.abstractmethod + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write a record batch to the underlying `BatchExportTemporaryFile`. + + Subclasses must override this to provide the actual implementation according to the supported + file format. + """ + pass async def __aenter__(self): """Context-manager protocol enter method.""" @@ -502,15 +542,17 @@ async def __aenter__(self): return self async def __aexit__(self, exc, value, tb): - """Context-manager protocol exit method.""" + """Context-manager protocol exit method. + + We flush the latest data available in the file. Subclasses that implement formats that require + written footers should override this method to write the footers before the last flush. + """ if self.last_inserted_at is not None and self.records_since_last_flush > 0: await self.flush(self.last_inserted_at, is_last=True) return self.batch_export_file.__exit__(exc, value, tb) - def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: - ... 
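To make the abstract contract concrete, here is a hypothetical TSV writer, not part of this change, showing the two things a subclass provides at this point in the series: constructing the base class with a `BatchExportTemporaryFile` and implementing `_write_record_batch`:

```python
import collections.abc
import csv

import pyarrow as pa

from posthog.temporal.batch_exports.batch_exports import (
    BatchExportTemporaryFile,
    BatchExportWriter,
    FlushCallable,
)


class TSVBatchExportWriter(BatchExportWriter):
    """A hypothetical `BatchExportWriter` for TSV, shown only to illustrate the contract."""

    def __init__(
        self,
        max_bytes: int,
        flush_callable: FlushCallable,
        field_names: collections.abc.Sequence[str],
        compression: str | None = None,
    ):
        super().__init__(
            batch_export_file=BatchExportTemporaryFile(compression=compression),
            max_bytes=max_bytes,
            flush_callable=flush_callable,
        )
        self._writer = csv.DictWriter(self.batch_export_file, fieldnames=field_names, delimiter="\t")

    def _write_record_batch(self, record_batch: pa.RecordBatch) -> None:
        # Only the format-specific write is needed; flushing is driven by the base
        # class's `write_record_batch`. At this stage subclasses still track counts.
        self._writer.writerows(record_batch.to_pylist())
        self.records_total += record_batch.num_rows
        self.records_since_last_flush += record_batch.num_rows
```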
- async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Issue a record batch write tracking progress and flushing if required.""" last_inserted_at = record_batch.column("_inserted_at")[0].as_py() column_names = record_batch.column_names @@ -524,6 +566,7 @@ async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: await self.flush(last_inserted_at) async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: + """Call the provided `flush_callable` and reset underlying file.""" await self.flush_callable( self.batch_export_file, self.records_since_last_flush, @@ -536,10 +579,13 @@ async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> N @property def bytes_since_last_flush(self) -> int: + """Access written bytes from underlying file for convenience.""" return self.batch_export_file.bytes_since_last_reset class JSONLBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for JSONLines format.""" + def __init__( self, max_bytes: int, @@ -547,14 +593,14 @@ def __init__( compression: None | str = None, default: typing.Callable = str, ): - self.batch_export_file = BatchExportTemporaryFile(compression=compression) - self.flush_callable = flush_callable + super().__init__( + batch_export_file=BatchExportTemporaryFile(compression=compression), + max_bytes=max_bytes, + flush_callable=flush_callable, + ) self.default = default - self.max_bytes = max_bytes - self.last_inserted_at = None - def write(self, content: bytes) -> int: n = self.batch_export_file.write(orjson.dumps(content, default=str) + b"\n") return n @@ -568,7 +614,9 @@ def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: self.records_since_last_flush += 1 -class CSVBatchExportWriter(csv.DictWriter, BatchExportWriter): +class CSVBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for CSV format.""" + def __init__( self, max_bytes: int, @@ -582,10 +630,13 @@ def __init__( quoting=csv.QUOTE_NONE, compression: str | None = None, ): - self.batch_export_file = BatchExportTemporaryFile(compression=compression) - self.flush_callable = flush_callable - super().__init__( + batch_export_file=BatchExportTemporaryFile(compression=compression), + max_bytes=max_bytes, + flush_callable=flush_callable, + ) + + self._csv_writer = csv.DictWriter( self.batch_export_file, fieldnames=field_names, extrasaction=extras_action, @@ -596,17 +647,16 @@ def __init__( lineterminator=line_terminator, ) - self.max_bytes = max_bytes - self.last_inserted_at = None - def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: - """Write records to a temporary file as JSONL.""" - self.writerows(record_batch.to_pylist()) + """Write records to a temporary file as CSV.""" + self._csv_writer.writerows(record_batch.to_pylist()) self.records_total += record_batch.num_rows self.records_since_last_flush += record_batch.num_rows -class ParquetBatchExportWriter(pq.ParquetWriter, BatchExportWriter): +class ParquetBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for Apache Parquet format.""" + def __init__( self, max_bytes: int, @@ -617,31 +667,32 @@ def __init__( compression: str | None = "snappy", compression_level: int | None = None, ): - self.batch_export_file = BatchExportTemporaryFile(compression=None) # Handle compression in ParquetWriter - self.flush_callable = flush_callable - super().__init__( + batch_export_file=BatchExportTemporaryFile(compression=None), # Handle compression in ParquetWriter + max_bytes=max_bytes, + 
flush_callable=flush_callable, + ) + + self._parquet_writer = pq.ParquetWriter( self.batch_export_file, schema=schema, version=version, use_dictionary=use_dictionary, - compression=compression, + # Compression *can* be `None`. + compression=compression, # type: ignore compression_level=compression_level, ) - self.max_bytes = max_bytes - self.last_inserted_at = None - async def __aexit__(self, exc, value, tb): - """Close underlying Parquet writer to include footer bytes before flushing last.""" - self.writer.close() - self.is_open = False + """Close underlying Parquet writer to include Parquet footer bytes before flushing last.""" + self._parquet_writer.writer.close() + self._parquet_writer.is_open = False await super().__aexit__(exc, value, tb) def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: - """Write records to a temporary file as JSONL.""" - self.write_batch(record_batch.select(self.schema.names)) + """Write records to a temporary file as Parquet.""" + self._parquet_writer.write_batch(record_batch.select(self._parquet_writer.schema.names)) self.records_total += record_batch.num_rows self.records_since_last_flush += record_batch.num_rows From 539f335c32c00cdde4b23e7b9c4ed06d7898c507 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Wed, 20 Mar 2024 17:37:01 +0100 Subject: [PATCH 03/14] refactor: More clearly separate writer from temporary file We now should be more explicit about what is the context in which the batch export temporary file is alive. The writer outlives this context, so it can be used by callers to, for example, check how many records were written. --- .../temporal/batch_exports/batch_exports.py | 146 ++++++++++++------ .../temporal/batch_exports/s3_batch_export.py | 75 +++++---- .../test_s3_batch_export_workflow.py | 4 +- 3 files changed, 148 insertions(+), 77 deletions(-) diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index b75b0494d9209..87bc747fe5c71 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -1,5 +1,6 @@ import abc import collections.abc +import contextlib import csv import dataclasses import datetime as dt @@ -489,6 +490,13 @@ def reset(self): ] +class UnsupportedFileFormatError(Exception): + """Raised when a writer for an unsupported file format is requested.""" + + def __init__(self, file_format: str, destination: str): + super().__init__(f"{file_format} is not a supported format for {destination} batch exports.") + + class BatchExportWriter(abc.ABC): """A temporary file writer to be used by batch export workflows. @@ -516,16 +524,57 @@ class BatchExportWriter(abc.ABC): def __init__( self, - batch_export_file: BatchExportTemporaryFile, flush_callable: FlushCallable, max_bytes: int, + file_kwargs: collections.abc.Mapping[str, typing.Any], ): - self.batch_export_file = batch_export_file - self.max_bytes = max_bytes self.flush_callable = flush_callable + self.max_bytes = max_bytes + self.file_kwargs = file_kwargs + + self._batch_export_file = None + self.reset_writer_tracking() + + def reset_writer_tracking(self): + self.last_inserted_at: dt.datetime | None = None self.records_total = 0 self.records_since_last_flush = 0 - self.last_inserted_at: dt.datetime | None = None + self.bytes_total = 0 + self.bytes_since_last_flush = 0 + + @contextlib.asynccontextmanager + async def open_temporary_file(self): + """Explicitly open the temporary file this writer is writing to. 
+ + The underlying `BatchExportTemporaryFile` is only accessible within this context manager. This helps + us separate the lifetime of the underlying temporary file from the writer: The writer may still be + accessed even after the temporary file is closed, while on the other hand we ensure the file and all + its data is flushed and not leaked outside the context. Any relevant tracking information + """ + self.reset_writer_tracking() + + with BatchExportTemporaryFile(**self.file_kwargs) as temp_file: + self._batch_export_file = temp_file + + try: + yield + finally: + self.track_bytes_written(temp_file) + + if self.last_inserted_at is not None and self.bytes_since_last_flush > 0: + # `bytes_since_last_flush` should be 0 unless: + # 1. The last batch wasn't flushed as it didn't reach `max_bytes`. + # 2. The last batch was flushed but there was another write after the last call to + # `write_record_batch`. For example, footer bytes. + await self.flush(self.last_inserted_at, is_last=True) + + self._batch_export_file = None + + @property + def batch_export_file(self): + if self._batch_export_file is None: + raise ValueError("Batch export file is closed. Did you forget to call 'open_temporary_file'?") + return self._batch_export_file @abc.abstractmethod def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: @@ -536,20 +585,13 @@ def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: """ pass - async def __aenter__(self): - """Context-manager protocol enter method.""" - self.batch_export_file.__enter__() - return self + def track_records_written(self, record_batch: pa.RecordBatch) -> None: + self.records_total += record_batch.num_rows + self.records_since_last_flush += record_batch.num_rows - async def __aexit__(self, exc, value, tb): - """Context-manager protocol exit method. - - We flush the latest data available in the file. Subclasses that implement formats that require - written footers should override this method to write the footers before the last flush. 
- """ - if self.last_inserted_at is not None and self.records_since_last_flush > 0: - await self.flush(self.last_inserted_at, is_last=True) - return self.batch_export_file.__exit__(exc, value, tb) + def track_bytes_written(self, batch_export_file: BatchExportTemporaryFile) -> None: + self.bytes_total = batch_export_file.bytes_total + self.bytes_since_last_flush = batch_export_file.bytes_since_last_reset async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: """Issue a record batch write tracking progress and flushing if required.""" @@ -561,6 +603,8 @@ async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: self._write_record_batch(record_batch.select(column_names)) self.last_inserted_at = last_inserted_at + self.track_records_written(record_batch) + self.track_bytes_written(self.batch_export_file) if self.bytes_since_last_flush >= self.max_bytes: await self.flush(last_inserted_at) @@ -575,12 +619,9 @@ async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> N is_last, ) self.batch_export_file.reset() - self.records_since_last_flush = 0 - @property - def bytes_since_last_flush(self) -> int: - """Access written bytes from underlying file for convenience.""" - return self.batch_export_file.bytes_since_last_reset + self.records_since_last_flush = 0 + self.bytes_since_last_flush = 0 class JSONLBatchExportWriter(BatchExportWriter): @@ -594,9 +635,9 @@ def __init__( default: typing.Callable = str, ): super().__init__( - batch_export_file=BatchExportTemporaryFile(compression=compression), max_bytes=max_bytes, flush_callable=flush_callable, + file_kwargs={"compression": compression}, ) self.default = default @@ -610,9 +651,6 @@ def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: for record in record_batch.to_pylist(): self.write(record) - self.records_total += 1 - self.records_since_last_flush += 1 - class CSVBatchExportWriter(BatchExportWriter): """A `BatchExportWriter` for CSV format.""" @@ -631,9 +669,9 @@ def __init__( compression: str | None = None, ): super().__init__( - batch_export_file=BatchExportTemporaryFile(compression=compression), max_bytes=max_bytes, flush_callable=flush_callable, + file_kwargs={"compression": compression}, ) self._csv_writer = csv.DictWriter( @@ -650,8 +688,6 @@ def __init__( def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: """Write records to a temporary file as CSV.""" self._csv_writer.writerows(record_batch.to_pylist()) - self.records_total += record_batch.num_rows - self.records_since_last_flush += record_batch.num_rows class ParquetBatchExportWriter(BatchExportWriter): @@ -663,38 +699,54 @@ def __init__( flush_callable: FlushCallable, schema: pa.Schema, version: str = "2.6", - use_dictionary: bool = True, compression: str | None = "snappy", compression_level: int | None = None, ): super().__init__( - batch_export_file=BatchExportTemporaryFile(compression=None), # Handle compression in ParquetWriter max_bytes=max_bytes, flush_callable=flush_callable, + file_kwargs={"compression": None}, # ParquetWriter handles compression ) + self.schema = schema + self.version = version + self.compression = compression + self.compression_level = compression_level - self._parquet_writer = pq.ParquetWriter( - self.batch_export_file, - schema=schema, - version=version, - use_dictionary=use_dictionary, - # Compression *can* be `None`. 
- compression=compression, # type: ignore - compression_level=compression_level, - ) + self._parquet_writer = None + + @property + def parquet_writer(self) -> pq.ParquetWriter: + if self._parquet_writer is None: + self._parquet_writer = pq.ParquetWriter( + self.batch_export_file, + schema=self.schema, + version=self.version, + # Compression *can* be `None`. + compression=self.compression, + compression_level=self.compression_level, + ) + return self._parquet_writer + + def ensure_parquet_writer_is_closed(self) -> None: + """Ensure ParquetWriter is closed as Parquet footer bytes are written on closing.""" + if self._parquet_writer is None: + return - async def __aexit__(self, exc, value, tb): - """Close underlying Parquet writer to include Parquet footer bytes before flushing last.""" self._parquet_writer.writer.close() - self._parquet_writer.is_open = False + self._parquet_writer = None - await super().__aexit__(exc, value, tb) + @contextlib.asynccontextmanager + async def open_temporary_file(self): + """Ensure underlying Parquet writer is closed before flushing and closing temporary file.""" + async with super().open_temporary_file(): + try: + yield + finally: + self.ensure_parquet_writer_is_closed() def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: """Write records to a temporary file as Parquet.""" - self._parquet_writer.write_batch(record_batch.select(self._parquet_writer.schema.names)) - self.records_total += record_batch.num_rows - self.records_since_last_flush += record_batch.num_rows + self.parquet_writer.write_batch(record_batch.select(self.parquet_writer.schema.names)) @dataclasses.dataclass diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 13e1a60bba999..836785fe388b9 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -20,9 +20,12 @@ from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( BatchExportTemporaryFile, + BatchExportWriter, CreateBatchExportRunInputs, + FlushCallable, JSONLBatchExportWriter, ParquetBatchExportWriter, + UnsupportedFileFormatError, UpdateBatchExportRunStatusInputs, create_export_run, default_fields, @@ -512,35 +515,27 @@ async def flush_to_s3( last_uploaded_part_timestamp = str(last_inserted_at) activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state()) - if inputs.file_format == "Parquet": - first_record_batch, record_iterator = peek_first_and_rewind(record_iterator) - first_record_batch = cast_record_batch_json_columns(first_record_batch) - column_names = first_record_batch.column_names - column_names.pop(column_names.index("_inserted_at")) - - schema = pa.schema( - # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other - # record batches have them as nullable. - # Until we figure it out, we set all fields to nullable. There are some fields we know - # are not nullable, but I'm opting for the more flexible option until we out why schemas differ - # between batches. 
- [field.with_nullable(True) for field in first_record_batch.select(column_names).schema] - ) - - writer = ParquetBatchExportWriter( - max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, - flush_callable=flush_to_s3, - compression=inputs.compression, - schema=schema, - ) - else: - writer = JSONLBatchExportWriter( - max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, - flush_callable=flush_to_s3, - compression=inputs.compression, - ) + first_record_batch, record_iterator = peek_first_and_rewind(record_iterator) + first_record_batch = cast_record_batch_json_columns(first_record_batch) + column_names = first_record_batch.column_names + column_names.pop(column_names.index("_inserted_at")) + + schema = pa.schema( + # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other + # record batches have them as nullable. + # Until we figure it out, we set all fields to nullable. There are some fields we know + # are not nullable, but I'm opting for the more flexible option until we out why schemas differ + # between batches. + [field.with_nullable(True) for field in first_record_batch.select(column_names).schema] + ) - async with writer: + writer = get_batch_export_writer( + inputs, + flush_callable=flush_to_s3, + max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, + schema=schema, + ) + async with writer.open_temporary_file(): rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() @@ -551,7 +546,29 @@ async def flush_to_s3( await s3_upload.complete() - return local_results_file.records_total + return writer.records_total + + +def get_batch_export_writer( + inputs, flush_callable: FlushCallable, max_bytes: int, schema: pa.Schema | None = None +) -> BatchExportWriter: + if inputs.file_format == "Parquet": + writer = ParquetBatchExportWriter( + max_bytes=max_bytes, + flush_callable=flush_callable, + compression=inputs.compression, + schema=schema, + ) + elif inputs.file_format == "JSONLines": + writer = JSONLBatchExportWriter( + max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, + flush_callable=flush_callable, + compression=inputs.compression, + ) + else: + raise UnsupportedFileFormatError(inputs.file_format, "S3") + + return writer def cast_record_batch_json_columns( diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index 3b9ff1dd5c7ac..2c24f16926349 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -382,7 +382,9 @@ async def test_insert_into_s3_activity_puts_data_into_s3( with override_settings( BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 ): # 5MB, the minimum for Multipart uploads - await activity_environment.run(insert_into_s3_activity, insert_inputs) + records_total = await activity_environment.run(insert_into_s3_activity, insert_inputs) + + assert records_total == 10005 await assert_clickhouse_records_in_s3( s3_compatible_client=minio_client, From 80510fc283e4ea6924d87bd69cec8f1cebd60bb2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Wed, 20 Mar 2024 18:40:12 +0100 Subject: [PATCH 04/14] test: More parquet testing --- .../temporal/batch_exports/s3_batch_export.py | 5 ++- .../test_s3_batch_export_workflow.py | 39 ++++++++++++++++--- 2 files changed, 37 insertions(+), 7 deletions(-) diff --git 
a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 836785fe388b9..24455d8716fb8 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -535,6 +535,7 @@ async def flush_to_s3( max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, schema=schema, ) + async with writer.open_temporary_file(): rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() @@ -602,7 +603,7 @@ class JsonScalar(pa.ExtensionScalar): def as_py(self) -> dict | None: if self.value: - return orjson.loads(self.value.as_py()) + return orjson.loads(self.value.as_py().encode("utf-8")) else: return None @@ -611,7 +612,7 @@ class JsonType(pa.ExtensionType): """Type for JSON binary strings.""" def __init__(self): - super().__init__(pa.binary(), "json") + super().__init__(pa.string(), "json") def __arrow_ext_serialize__(self): return b"" diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index 2c24f16926349..8c700fb191f2b 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -109,6 +109,15 @@ def s3_key_prefix(): return f"posthog-events-{str(uuid4())}" +@pytest.fixture +def file_format(request) -> str: + """S3 file format.""" + try: + return request.param + except AttributeError: + return f"JSONLines" + + async def delete_all_from_s3(minio_client, bucket_name: str, key_prefix: str): """Delete all objects in bucket_name under key_prefix.""" response = await minio_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix) @@ -140,7 +149,7 @@ async def minio_client(bucket_name): await minio_client.delete_bucket(Bucket=bucket_name) -async def read_parquet_from_s3(bucket_name: str, key: str) -> list: +async def read_parquet_from_s3(bucket_name: str, key: str, json_columns) -> list: async with aioboto3.Session().client("sts") as sts: try: await sts.get_caller_identity() @@ -158,7 +167,19 @@ async def read_parquet_from_s3(bucket_name: str, key: str) -> list: parquet_data = [] for batch in table.to_batches(): - parquet_data.extend(batch.to_pylist()) + for record in batch.to_pylist(): + casted_record = {} + for k, v in record.items(): + if isinstance(v, dt.datetime): + # We read data from clickhouse as string, but parquet already casts them as dates. + # To facilitate comparison, we isoformat the dates. + casted_record[k] = v.isoformat() + elif k in json_columns and v is not None: + # Parquet doesn't have a variable map type, so JSON fields are just strings. 
+ casted_record[k] = json.loads(v) + else: + casted_record[k] = v + parquet_data.append(casted_record) return parquet_data @@ -215,8 +236,10 @@ async def assert_clickhouse_records_in_s3( key = objects["Contents"][0].get("Key") assert key + json_columns = ("properties", "person_properties", "set", "set_once") + if file_format == "Parquet": - s3_data = await read_parquet_from_s3(bucket_name, key) + s3_data = await read_parquet_from_s3(bucket_name, key, json_columns) elif file_format == "JSONLines": s3_object = await s3_compatible_client.get_object(Bucket=bucket_name, Key=key) @@ -230,8 +253,6 @@ async def assert_clickhouse_records_in_s3( else: schema_column_names = [field["alias"] for field in s3_default_fields()] - json_columns = ("properties", "person_properties", "set", "set_once") - expected_records = [] for record_batch in iter_records( client=clickhouse_client, @@ -412,6 +433,7 @@ async def s3_batch_export( exclude_events, temporal_client, encryption, + file_format, ): destination_data = { "type": "S3", @@ -426,6 +448,7 @@ async def s3_batch_export( "exclude_events": exclude_events, "encryption": encryption, "kms_key_id": os.getenv("S3_TEST_KMS_KEY_ID") if encryption == "aws:kms" else None, + "file_format": file_format, }, } @@ -451,6 +474,7 @@ async def s3_batch_export( @pytest.mark.parametrize("compression", [None, "gzip", "brotli"], indirect=True) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) +@pytest.mark.parametrize("file_format", ["JSONLines", "Parquet"], indirect=True) async def test_s3_export_workflow_with_minio_bucket( clickhouse_client, minio_client, @@ -462,6 +486,7 @@ async def test_s3_export_workflow_with_minio_bucket( exclude_events, s3_key_prefix, batch_export_schema, + file_format, ): """Test S3BatchExport Workflow end-to-end by using a local MinIO bucket instead of S3. @@ -549,6 +574,7 @@ async def test_s3_export_workflow_with_minio_bucket( batch_export_schema=batch_export_schema, exclude_events=exclude_events, compression=compression, + file_format=file_format, ) @@ -578,6 +604,7 @@ async def s3_client(bucket_name, s3_key_prefix): @pytest.mark.parametrize("encryption", [None, "AES256", "aws:kms"], indirect=True) @pytest.mark.parametrize("bucket_name", [os.getenv("S3_TEST_BUCKET")], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) +@pytest.mark.parametrize("file_format", ["JSONLines", "Parquet"]) async def test_s3_export_workflow_with_s3_bucket( s3_client, clickhouse_client, @@ -590,6 +617,7 @@ async def test_s3_export_workflow_with_s3_bucket( exclude_events, ateam, batch_export_schema, + file_format, ): """Test S3 Export Workflow end-to-end by using an S3 bucket. 
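As the casting loop above illustrates, the two formats round-trip values differently: JSONLines keeps JSON-native types, whereas Parquet stores real timestamps and leaves the JSON columns as plain strings. A small, hypothetical helper condensing that normalization:

```python
import datetime as dt
import json


def normalize_parquet_record(record: dict, json_columns: tuple[str, ...]) -> dict:
    """Normalize a Parquet record so it compares equal to its JSONLines counterpart."""
    normalized = {}
    for key, value in record.items():
        if isinstance(value, dt.datetime):
            # Parquet stores timestamps natively; the JSONLines export keeps ISO strings.
            normalized[key] = value.isoformat()
        elif key in json_columns and value is not None:
            # Parquet has no JSON type, so JSON columns come back as strings.
            normalized[key] = json.loads(value)
        else:
            normalized[key] = value
    return normalized
```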
@@ -687,6 +715,7 @@ async def test_s3_export_workflow_with_s3_bucket( exclude_events=exclude_events, include_events=None, compression=compression, + file_format=file_format, ) From 8cfce4c9ea05503c91781c108bc64c51bb1c3402 Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Wed, 20 Mar 2024 17:58:08 +0000 Subject: [PATCH 05/14] Update query snapshots --- .../transforms/test/__snapshots__/test_in_cohort.ambr | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr b/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr index 9ff7f8ee0ab49..e0f5ea847110d 100644 --- a/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr +++ b/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr @@ -31,7 +31,7 @@ FROM events LEFT JOIN ( SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id FROM person_static_cohort - WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [12]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) + WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [11]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 @@ -42,7 +42,7 @@ FROM events LEFT JOIN ( SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id FROM static_cohort_people - WHERE in(cohort_id, [12])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) + WHERE in(cohort_id, [11])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) WHERE and(1, equals(__in_cohort.matched, 1)) LIMIT 100 ''' @@ -55,7 +55,7 @@ FROM events LEFT JOIN ( SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id FROM person_static_cohort - WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [13]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) + WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [12]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 @@ -66,7 +66,7 @@ FROM events LEFT JOIN ( SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id FROM static_cohort_people - WHERE in(cohort_id, [13])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) + WHERE in(cohort_id, [12])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) WHERE and(1, equals(__in_cohort.matched, 1)) LIMIT 100 ''' From 588626c024c84431126a0bacb88bb8e96299326f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Wed, 20 Mar 2024 19:10:32 +0100 Subject: [PATCH 06/14] fix: Typing --- .../temporal/batch_exports/batch_exports.py | 40 +++++++++++++------ .../temporal/batch_exports/s3_batch_export.py | 2 + 2 files changed, 29 insertions(+), 13 deletions(-) diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index 87bc747fe5c71..c99c926c7ebcd 100644 --- 
a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -532,7 +532,7 @@ def __init__( self.max_bytes = max_bytes self.file_kwargs = file_kwargs - self._batch_export_file = None + self._batch_export_file: BatchExportTemporaryFile | None = None self.reset_writer_tracking() def reset_writer_tracking(self): @@ -673,21 +673,35 @@ def __init__( flush_callable=flush_callable, file_kwargs={"compression": compression}, ) + self.field_names = field_names + self.extras_action: typing.Literal["raise", "ignore"] = extras_action + self.delimiter = delimiter + self.quote_char = quote_char + self.escape_char = escape_char + self.line_terminator = line_terminator + self.quoting = quoting - self._csv_writer = csv.DictWriter( - self.batch_export_file, - fieldnames=field_names, - extrasaction=extras_action, - delimiter=delimiter, - quotechar=quote_char, - escapechar=escape_char, - quoting=quoting, - lineterminator=line_terminator, - ) + self._csv_writer: csv.DictWriter | None = None + + @property + def csv_writer(self) -> csv.DictWriter: + if self._csv_writer is None: + self._csv_writer = csv.DictWriter( + self.batch_export_file, + fieldnames=self.field_names, + extrasaction=self.extras_action, + delimiter=self.delimiter, + quotechar=self.quote_char, + escapechar=self.escape_char, + quoting=self.quoting, + lineterminator=self.line_terminator, + ) + + return self._csv_writer def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: """Write records to a temporary file as CSV.""" - self._csv_writer.writerows(record_batch.to_pylist()) + self.csv_writer.writerows(record_batch.to_pylist()) class ParquetBatchExportWriter(BatchExportWriter): @@ -712,7 +726,7 @@ def __init__( self.compression = compression self.compression_level = compression_level - self._parquet_writer = None + self._parquet_writer: pq.ParquetWriter | None = None @property def parquet_writer(self) -> pq.ParquetWriter: diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 24455d8716fb8..ef16340a6d0e6 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -553,6 +553,8 @@ async def flush_to_s3( def get_batch_export_writer( inputs, flush_callable: FlushCallable, max_bytes: int, schema: pa.Schema | None = None ) -> BatchExportWriter: + writer: BatchExportWriter + if inputs.file_format == "Parquet": writer = ParquetBatchExportWriter( max_bytes=max_bytes, From d51400b7ba0b8516d6cabb12449ed138f71f56ef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Wed, 20 Mar 2024 19:26:26 +0100 Subject: [PATCH 07/14] refactor: Move temporary file to new module --- .../temporal/batch_exports/batch_exports.py | 482 ----------------- .../batch_exports/bigquery_batch_export.py | 4 +- .../batch_exports/http_batch_export.py | 6 +- .../batch_exports/postgres_batch_export.py | 4 +- .../temporal/batch_exports/s3_batch_export.py | 16 +- .../batch_exports/snowflake_batch_export.py | 4 +- .../temporal/batch_exports/temporary_file.py | 488 ++++++++++++++++++ .../tests/batch_exports/test_batch_exports.py | 182 ------- .../batch_exports/test_temporary_file.py | 188 +++++++ 9 files changed, 698 insertions(+), 676 deletions(-) create mode 100644 posthog/temporal/batch_exports/temporary_file.py create mode 100644 posthog/temporal/tests/batch_exports/test_temporary_file.py diff --git a/posthog/temporal/batch_exports/batch_exports.py 
b/posthog/temporal/batch_exports/batch_exports.py index c99c926c7ebcd..fd8aedb544142 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -1,19 +1,11 @@ -import abc import collections.abc -import contextlib -import csv import dataclasses import datetime as dt -import gzip -import tempfile import typing import uuid from string import Template -import brotli -import orjson import pyarrow as pa -import pyarrow.parquet as pq from asgiref.sync import sync_to_async from django.conf import settings from temporalio import activity, exceptions, workflow @@ -289,480 +281,6 @@ def get_data_interval(interval: str, data_interval_end: str | None) -> tuple[dt. return (data_interval_start_dt, data_interval_end_dt) -def json_dumps_bytes(d) -> bytes: - return orjson.dumps(d, default=str) - - -class BatchExportTemporaryFile: - """A TemporaryFile used to as an intermediate step while exporting data. - - This class does not implement the file-like interface but rather passes any calls - to the underlying tempfile.NamedTemporaryFile. We do override 'write' methods - to allow tracking bytes and records. - """ - - def __init__( - self, - mode: str = "w+b", - buffering=-1, - compression: str | None = None, - encoding: str | None = None, - newline: str | None = None, - suffix: str | None = None, - prefix: str | None = None, - dir: str | None = None, - *, - errors: str | None = None, - ): - self._file = tempfile.NamedTemporaryFile( - mode=mode, - encoding=encoding, - newline=newline, - buffering=buffering, - suffix=suffix, - prefix=prefix, - dir=dir, - errors=errors, - ) - self.compression = compression - self.bytes_total = 0 - self.records_total = 0 - self.bytes_since_last_reset = 0 - self.records_since_last_reset = 0 - self._brotli_compressor = None - - def __getattr__(self, name): - """Pass get attr to underlying tempfile.NamedTemporaryFile.""" - return self._file.__getattr__(name) - - def __enter__(self): - """Context-manager protocol enter method.""" - self._file.__enter__() - return self - - def __exit__(self, exc, value, tb): - """Context-manager protocol exit method.""" - return self._file.__exit__(exc, value, tb) - - def __iter__(self): - yield from self._file - - @property - def brotli_compressor(self): - if self._brotli_compressor is None: - self._brotli_compressor = brotli.Compressor() - return self._brotli_compressor - - def compress(self, content: bytes | str) -> bytes: - if isinstance(content, str): - encoded = content.encode("utf-8") - else: - encoded = content - - match self.compression: - case "gzip": - return gzip.compress(encoded) - case "brotli": - self.brotli_compressor.process(encoded) - return self.brotli_compressor.flush() - case None: - return encoded - case _: - raise ValueError(f"Unsupported compression: '{self.compression}'") - - def write(self, content: bytes | str): - """Write bytes to underlying file keeping track of how many bytes were written.""" - compressed_content = self.compress(content) - - if "b" in self.mode: - result = self._file.write(compressed_content) - else: - result = self._file.write(compressed_content.decode("utf-8")) - - self.bytes_total += result - self.bytes_since_last_reset += result - - return result - - def write_record_as_bytes(self, record: bytes): - result = self.write(record) - - self.records_total += 1 - self.records_since_last_reset += 1 - - return result - - def write_records_to_jsonl(self, records): - """Write records to a temporary file as JSONL.""" - if len(records) == 1: - jsonl_dump = 
orjson.dumps(records[0], option=orjson.OPT_APPEND_NEWLINE, default=str) - else: - jsonl_dump = b"\n".join(map(json_dumps_bytes, records)) - - result = self.write(jsonl_dump) - - self.records_total += len(records) - self.records_since_last_reset += len(records) - - return result - - def write_records_to_csv( - self, - records, - fieldnames: None | collections.abc.Sequence[str] = None, - extrasaction: typing.Literal["raise", "ignore"] = "ignore", - delimiter: str = ",", - quotechar: str = '"', - escapechar: str | None = "\\", - lineterminator: str = "\n", - quoting=csv.QUOTE_NONE, - ): - """Write records to a temporary file as CSV.""" - if len(records) == 0: - return - - if fieldnames is None: - fieldnames = list(records[0].keys()) - - writer = csv.DictWriter( - self, - fieldnames=fieldnames, - extrasaction=extrasaction, - delimiter=delimiter, - quotechar=quotechar, - escapechar=escapechar, - quoting=quoting, - lineterminator=lineterminator, - ) - writer.writerows(records) - - self.records_total += len(records) - self.records_since_last_reset += len(records) - - def write_records_to_tsv( - self, - records, - fieldnames: None | list[str] = None, - extrasaction: typing.Literal["raise", "ignore"] = "ignore", - quotechar: str = '"', - escapechar: str | None = "\\", - lineterminator: str = "\n", - quoting=csv.QUOTE_NONE, - ): - """Write records to a temporary file as TSV.""" - return self.write_records_to_csv( - records, - fieldnames=fieldnames, - extrasaction=extrasaction, - delimiter="\t", - quotechar=quotechar, - escapechar=escapechar, - quoting=quoting, - lineterminator=lineterminator, - ) - - def rewind(self): - """Rewind the file before reading it.""" - if self.compression == "brotli": - result = self._file.write(self.brotli_compressor.finish()) - - self.bytes_total += result - self.bytes_since_last_reset += result - - self._brotli_compressor = None - - self._file.seek(0) - - def reset(self): - """Reset underlying file by truncating it. - - Also resets the tracker attributes for bytes and records since last reset. - """ - self._file.seek(0) - self._file.truncate() - - self.bytes_since_last_reset = 0 - self.records_since_last_reset = 0 - - -FlushCallable = collections.abc.Callable[ - [BatchExportTemporaryFile, int, int, dt.datetime, bool], collections.abc.Awaitable[None] -] - - -class UnsupportedFileFormatError(Exception): - """Raised when a writer for an unsupported file format is requested.""" - - def __init__(self, file_format: str, destination: str): - super().__init__(f"{file_format} is not a supported format for {destination} batch exports.") - - -class BatchExportWriter(abc.ABC): - """A temporary file writer to be used by batch export workflows. - - Subclasses should define `_write_record_batch` with the particular intricacies - of the format they are writing as. - - Actual writing calls are passed to the underlying `batch_export_file`. - - Attributes: - batch_export_file: The temporary file we are writing to. - bytes_flush_threshold: Flush the temporary file with the provided `flush_callable` - upon reaching or surpassing this threshold. Keep in mind we write on a RecordBatch - per RecordBatch basis, which means the threshold will be surpassed by at most the - size of a RecordBatch before a flush occurs. - flush_callable: A callback to flush the temporary file when `bytes_flush_treshold` is reached. - The temporary file will be reset after calling `flush_callable`. - records_total: The total number of records (not RecordBatches!) written. 
- records_since_last_flush: The number of records written since last flush. - last_inserted_at: Latest `_inserted_at` written. This attribute leaks some implementation - details, as we are making two assumptions about the RecordBatches being written: - * We assume RecordBatches are sorted on `_inserted_at`, which currently happens with - an `ORDER BY` clause. - * We assume `_inserted_at` is present, as it's added to all batch export queries. - """ - - def __init__( - self, - flush_callable: FlushCallable, - max_bytes: int, - file_kwargs: collections.abc.Mapping[str, typing.Any], - ): - self.flush_callable = flush_callable - self.max_bytes = max_bytes - self.file_kwargs = file_kwargs - - self._batch_export_file: BatchExportTemporaryFile | None = None - self.reset_writer_tracking() - - def reset_writer_tracking(self): - self.last_inserted_at: dt.datetime | None = None - self.records_total = 0 - self.records_since_last_flush = 0 - self.bytes_total = 0 - self.bytes_since_last_flush = 0 - - @contextlib.asynccontextmanager - async def open_temporary_file(self): - """Explicitly open the temporary file this writer is writing to. - - The underlying `BatchExportTemporaryFile` is only accessible within this context manager. This helps - us separate the lifetime of the underlying temporary file from the writer: The writer may still be - accessed even after the temporary file is closed, while on the other hand we ensure the file and all - its data is flushed and not leaked outside the context. Any relevant tracking information - """ - self.reset_writer_tracking() - - with BatchExportTemporaryFile(**self.file_kwargs) as temp_file: - self._batch_export_file = temp_file - - try: - yield - finally: - self.track_bytes_written(temp_file) - - if self.last_inserted_at is not None and self.bytes_since_last_flush > 0: - # `bytes_since_last_flush` should be 0 unless: - # 1. The last batch wasn't flushed as it didn't reach `max_bytes`. - # 2. The last batch was flushed but there was another write after the last call to - # `write_record_batch`. For example, footer bytes. - await self.flush(self.last_inserted_at, is_last=True) - - self._batch_export_file = None - - @property - def batch_export_file(self): - if self._batch_export_file is None: - raise ValueError("Batch export file is closed. Did you forget to call 'open_temporary_file'?") - return self._batch_export_file - - @abc.abstractmethod - def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: - """Write a record batch to the underlying `BatchExportTemporaryFile`. - - Subclasses must override this to provide the actual implementation according to the supported - file format. 
- """ - pass - - def track_records_written(self, record_batch: pa.RecordBatch) -> None: - self.records_total += record_batch.num_rows - self.records_since_last_flush += record_batch.num_rows - - def track_bytes_written(self, batch_export_file: BatchExportTemporaryFile) -> None: - self.bytes_total = batch_export_file.bytes_total - self.bytes_since_last_flush = batch_export_file.bytes_since_last_reset - - async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: - """Issue a record batch write tracking progress and flushing if required.""" - last_inserted_at = record_batch.column("_inserted_at")[0].as_py() - - column_names = record_batch.column_names - column_names.pop(column_names.index("_inserted_at")) - - self._write_record_batch(record_batch.select(column_names)) - - self.last_inserted_at = last_inserted_at - self.track_records_written(record_batch) - self.track_bytes_written(self.batch_export_file) - - if self.bytes_since_last_flush >= self.max_bytes: - await self.flush(last_inserted_at) - - async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: - """Call the provided `flush_callable` and reset underlying file.""" - await self.flush_callable( - self.batch_export_file, - self.records_since_last_flush, - self.bytes_since_last_flush, - last_inserted_at, - is_last, - ) - self.batch_export_file.reset() - - self.records_since_last_flush = 0 - self.bytes_since_last_flush = 0 - - -class JSONLBatchExportWriter(BatchExportWriter): - """A `BatchExportWriter` for JSONLines format.""" - - def __init__( - self, - max_bytes: int, - flush_callable: FlushCallable, - compression: None | str = None, - default: typing.Callable = str, - ): - super().__init__( - max_bytes=max_bytes, - flush_callable=flush_callable, - file_kwargs={"compression": compression}, - ) - - self.default = default - - def write(self, content: bytes) -> int: - n = self.batch_export_file.write(orjson.dumps(content, default=str) + b"\n") - return n - - def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: - """Write records to a temporary file as JSONL.""" - for record in record_batch.to_pylist(): - self.write(record) - - -class CSVBatchExportWriter(BatchExportWriter): - """A `BatchExportWriter` for CSV format.""" - - def __init__( - self, - max_bytes: int, - flush_callable: FlushCallable, - field_names: collections.abc.Sequence[str], - extras_action: typing.Literal["raise", "ignore"] = "ignore", - delimiter: str = ",", - quote_char: str = '"', - escape_char: str | None = "\\", - line_terminator: str = "\n", - quoting=csv.QUOTE_NONE, - compression: str | None = None, - ): - super().__init__( - max_bytes=max_bytes, - flush_callable=flush_callable, - file_kwargs={"compression": compression}, - ) - self.field_names = field_names - self.extras_action: typing.Literal["raise", "ignore"] = extras_action - self.delimiter = delimiter - self.quote_char = quote_char - self.escape_char = escape_char - self.line_terminator = line_terminator - self.quoting = quoting - - self._csv_writer: csv.DictWriter | None = None - - @property - def csv_writer(self) -> csv.DictWriter: - if self._csv_writer is None: - self._csv_writer = csv.DictWriter( - self.batch_export_file, - fieldnames=self.field_names, - extrasaction=self.extras_action, - delimiter=self.delimiter, - quotechar=self.quote_char, - escapechar=self.escape_char, - quoting=self.quoting, - lineterminator=self.line_terminator, - ) - - return self._csv_writer - - def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: - """Write 
records to a temporary file as CSV.""" - self.csv_writer.writerows(record_batch.to_pylist()) - - -class ParquetBatchExportWriter(BatchExportWriter): - """A `BatchExportWriter` for Apache Parquet format.""" - - def __init__( - self, - max_bytes: int, - flush_callable: FlushCallable, - schema: pa.Schema, - version: str = "2.6", - compression: str | None = "snappy", - compression_level: int | None = None, - ): - super().__init__( - max_bytes=max_bytes, - flush_callable=flush_callable, - file_kwargs={"compression": None}, # ParquetWriter handles compression - ) - self.schema = schema - self.version = version - self.compression = compression - self.compression_level = compression_level - - self._parquet_writer: pq.ParquetWriter | None = None - - @property - def parquet_writer(self) -> pq.ParquetWriter: - if self._parquet_writer is None: - self._parquet_writer = pq.ParquetWriter( - self.batch_export_file, - schema=self.schema, - version=self.version, - # Compression *can* be `None`. - compression=self.compression, - compression_level=self.compression_level, - ) - return self._parquet_writer - - def ensure_parquet_writer_is_closed(self) -> None: - """Ensure ParquetWriter is closed as Parquet footer bytes are written on closing.""" - if self._parquet_writer is None: - return - - self._parquet_writer.writer.close() - self._parquet_writer = None - - @contextlib.asynccontextmanager - async def open_temporary_file(self): - """Ensure underlying Parquet writer is closed before flushing and closing temporary file.""" - async with super().open_temporary_file(): - try: - yield - finally: - self.ensure_parquet_writer_is_closed() - - def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: - """Write records to a temporary file as Parquet.""" - self.parquet_writer.write_batch(record_batch.select(self.parquet_writer.schema.names)) - - @dataclasses.dataclass class CreateBatchExportRunInputs: """Inputs to the create_export_run activity. 
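The writer classes removed above reappear in the new `posthog/temporal/batch_exports/temporary_file.py` module later in this patch; to make their contract concrete, here is a minimal sketch of how a destination activity drives one of them. Only the writer API and the `FlushCallable` signature come from this series; the in-memory destination, the toy record batch, and the `flush_to_destination` name are illustrative, mirroring the pattern used by the unit tests added later in the series.

import asyncio
import datetime as dt
import io

import pyarrow as pa

from posthog.temporal.batch_exports.temporary_file import (
    BatchExportTemporaryFile,
    JSONLBatchExportWriter,
)


async def export_one_batch() -> bytes:
    """Stage a record batch in a temporary file and flush it to an in-memory destination."""
    destination = io.BytesIO()

    async def flush_to_destination(
        batch_export_file: BatchExportTemporaryFile,
        records_since_last_flush: int,
        bytes_since_last_flush: int,
        last_inserted_at: dt.datetime,
        is_last: bool,
    ) -> None:
        # A real destination would upload a part to S3, run a COPY, etc.
        destination.write(batch_export_file.read())

    writer = JSONLBatchExportWriter(max_bytes=1024 * 1024, flush_callable=flush_to_destination)

    record_batch = pa.RecordBatch.from_pydict(
        {
            "event": pa.array(["test-event-0", "test-event-1"]),
            "_inserted_at": pa.array([dt.datetime(2024, 3, 18, 14, 0), dt.datetime(2024, 3, 18, 14, 1)]),
        }
    )

    # The temporary file only exists inside this context manager; anything still
    # buffered when it exits is flushed with `is_last=True`.
    async with writer.open_temporary_file():
        await writer.write_record_batch(record_batch)

    return destination.getvalue()


if __name__ == "__main__":
    print(asyncio.run(export_one_batch()))

Judging by the imports added to `s3_batch_export.py` in this patch, the S3 destination drives the same interface, choosing between `JSONLBatchExportWriter` and `ParquetBatchExportWriter` and raising `UnsupportedFileFormatError` when a requested format has no writer.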
diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index a0469de79bb9e..b754a7add16b4 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -15,7 +15,6 @@ from posthog.batch_exports.service import BatchExportField, BatchExportSchema, BigQueryBatchExportInputs from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -29,6 +28,9 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, +) from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger diff --git a/posthog/temporal/batch_exports/http_batch_export.py b/posthog/temporal/batch_exports/http_batch_export.py index 8aca65c80ff38..2866d50c99876 100644 --- a/posthog/temporal/batch_exports/http_batch_export.py +++ b/posthog/temporal/batch_exports/http_batch_export.py @@ -13,7 +13,6 @@ from posthog.models import BatchExportRun from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -21,12 +20,15 @@ get_data_interval, get_rows_count, iter_records, - json_dumps_bytes, ) from posthog.temporal.batch_exports.metrics import ( get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, + json_dumps_bytes, +) from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py index 5dbfc6faa4acf..98969ee78de79 100644 --- a/posthog/temporal/batch_exports/postgres_batch_export.py +++ b/posthog/temporal/batch_exports/postgres_batch_export.py @@ -17,7 +17,6 @@ from posthog.batch_exports.service import BatchExportField, BatchExportSchema, PostgresBatchExportInputs from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -31,6 +30,9 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, +) from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index ef16340a6d0e6..70fa3e3e71490 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -19,13 +19,7 @@ from posthog.batch_exports.service import BatchExportField, BatchExportSchema, S3BatchExportInputs from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - 
BatchExportTemporaryFile, - BatchExportWriter, CreateBatchExportRunInputs, - FlushCallable, - JSONLBatchExportWriter, - ParquetBatchExportWriter, - UnsupportedFileFormatError, UpdateBatchExportRunStatusInputs, create_export_run, default_fields, @@ -38,6 +32,14 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, + BatchExportWriter, + FlushCallable, + JSONLBatchExportWriter, + ParquetBatchExportWriter, + UnsupportedFileFormatError, +) from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -474,7 +476,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> int: last_uploaded_part_timestamp: str | None = None - async def worker_shutdown_handler(): + async def worker_shutdown_handler() -> None: """Handle the Worker shutting down by heart-beating our latest status.""" await activity.wait_for_worker_shutdown() logger.warn( diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index be94eca89a799..9053f3e1006ad 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -18,7 +18,6 @@ from posthog.batch_exports.service import BatchExportField, BatchExportSchema, SnowflakeBatchExportInputs from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -32,6 +31,9 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, +) from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py new file mode 100644 index 0000000000000..40dfc72bc64c9 --- /dev/null +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -0,0 +1,488 @@ +"""This module contains a temporary file to stage data in batch exports.""" +import abc +import collections.abc +import contextlib +import csv +import datetime as dt +import gzip +import tempfile +import typing + +import brotli +import orjson +import pyarrow as pa +import pyarrow.parquet as pq + + +def json_dumps_bytes(d) -> bytes: + return orjson.dumps(d, default=str) + + +class BatchExportTemporaryFile: + """A TemporaryFile used to as an intermediate step while exporting data. + + This class does not implement the file-like interface but rather passes any calls + to the underlying tempfile.NamedTemporaryFile. We do override 'write' methods + to allow tracking bytes and records. 
+ """ + + def __init__( + self, + mode: str = "w+b", + buffering=-1, + compression: str | None = None, + encoding: str | None = None, + newline: str | None = None, + suffix: str | None = None, + prefix: str | None = None, + dir: str | None = None, + *, + errors: str | None = None, + ): + self._file = tempfile.NamedTemporaryFile( + mode=mode, + encoding=encoding, + newline=newline, + buffering=buffering, + suffix=suffix, + prefix=prefix, + dir=dir, + errors=errors, + ) + self.compression = compression + self.bytes_total = 0 + self.records_total = 0 + self.bytes_since_last_reset = 0 + self.records_since_last_reset = 0 + self._brotli_compressor = None + + def __getattr__(self, name): + """Pass get attr to underlying tempfile.NamedTemporaryFile.""" + return self._file.__getattr__(name) + + def __enter__(self): + """Context-manager protocol enter method.""" + self._file.__enter__() + return self + + def __exit__(self, exc, value, tb): + """Context-manager protocol exit method.""" + return self._file.__exit__(exc, value, tb) + + def __iter__(self): + yield from self._file + + @property + def brotli_compressor(self): + if self._brotli_compressor is None: + self._brotli_compressor = brotli.Compressor() + return self._brotli_compressor + + def compress(self, content: bytes | str) -> bytes: + if isinstance(content, str): + encoded = content.encode("utf-8") + else: + encoded = content + + match self.compression: + case "gzip": + return gzip.compress(encoded) + case "brotli": + self.brotli_compressor.process(encoded) + return self.brotli_compressor.flush() + case None: + return encoded + case _: + raise ValueError(f"Unsupported compression: '{self.compression}'") + + def write(self, content: bytes | str): + """Write bytes to underlying file keeping track of how many bytes were written.""" + compressed_content = self.compress(content) + + if "b" in self.mode: + result = self._file.write(compressed_content) + else: + result = self._file.write(compressed_content.decode("utf-8")) + + self.bytes_total += result + self.bytes_since_last_reset += result + + return result + + def write_record_as_bytes(self, record: bytes): + result = self.write(record) + + self.records_total += 1 + self.records_since_last_reset += 1 + + return result + + def write_records_to_jsonl(self, records): + """Write records to a temporary file as JSONL.""" + if len(records) == 1: + jsonl_dump = orjson.dumps(records[0], option=orjson.OPT_APPEND_NEWLINE, default=str) + else: + jsonl_dump = b"\n".join(map(json_dumps_bytes, records)) + + result = self.write(jsonl_dump) + + self.records_total += len(records) + self.records_since_last_reset += len(records) + + return result + + def write_records_to_csv( + self, + records, + fieldnames: None | collections.abc.Sequence[str] = None, + extrasaction: typing.Literal["raise", "ignore"] = "ignore", + delimiter: str = ",", + quotechar: str = '"', + escapechar: str | None = "\\", + lineterminator: str = "\n", + quoting=csv.QUOTE_NONE, + ): + """Write records to a temporary file as CSV.""" + if len(records) == 0: + return + + if fieldnames is None: + fieldnames = list(records[0].keys()) + + writer = csv.DictWriter( + self, + fieldnames=fieldnames, + extrasaction=extrasaction, + delimiter=delimiter, + quotechar=quotechar, + escapechar=escapechar, + quoting=quoting, + lineterminator=lineterminator, + ) + writer.writerows(records) + + self.records_total += len(records) + self.records_since_last_reset += len(records) + + def write_records_to_tsv( + self, + records, + fieldnames: None | list[str] = None, + 
extrasaction: typing.Literal["raise", "ignore"] = "ignore", + quotechar: str = '"', + escapechar: str | None = "\\", + lineterminator: str = "\n", + quoting=csv.QUOTE_NONE, + ): + """Write records to a temporary file as TSV.""" + return self.write_records_to_csv( + records, + fieldnames=fieldnames, + extrasaction=extrasaction, + delimiter="\t", + quotechar=quotechar, + escapechar=escapechar, + quoting=quoting, + lineterminator=lineterminator, + ) + + def rewind(self): + """Rewind the file before reading it.""" + if self.compression == "brotli": + result = self._file.write(self.brotli_compressor.finish()) + + self.bytes_total += result + self.bytes_since_last_reset += result + + self._brotli_compressor = None + + self._file.seek(0) + + def reset(self): + """Reset underlying file by truncating it. + + Also resets the tracker attributes for bytes and records since last reset. + """ + self._file.seek(0) + self._file.truncate() + + self.bytes_since_last_reset = 0 + self.records_since_last_reset = 0 + + +FlushCallable = collections.abc.Callable[ + [BatchExportTemporaryFile, int, int, dt.datetime, bool], collections.abc.Awaitable[None] +] + + +class UnsupportedFileFormatError(Exception): + """Raised when a writer for an unsupported file format is requested.""" + + def __init__(self, file_format: str, destination: str): + super().__init__(f"{file_format} is not a supported format for {destination} batch exports.") + + +class BatchExportWriter(abc.ABC): + """A temporary file writer to be used by batch export workflows. + + Subclasses should define `_write_record_batch` with the particular intricacies + of the format they are writing as. + + Actual writing calls are passed to the underlying `batch_export_file`. + + Attributes: + batch_export_file: The temporary file we are writing to. + bytes_flush_threshold: Flush the temporary file with the provided `flush_callable` + upon reaching or surpassing this threshold. Keep in mind we write on a RecordBatch + per RecordBatch basis, which means the threshold will be surpassed by at most the + size of a RecordBatch before a flush occurs. + flush_callable: A callback to flush the temporary file when `bytes_flush_treshold` is reached. + The temporary file will be reset after calling `flush_callable`. + records_total: The total number of records (not RecordBatches!) written. + records_since_last_flush: The number of records written since last flush. + last_inserted_at: Latest `_inserted_at` written. This attribute leaks some implementation + details, as we are making two assumptions about the RecordBatches being written: + * We assume RecordBatches are sorted on `_inserted_at`, which currently happens with + an `ORDER BY` clause. + * We assume `_inserted_at` is present, as it's added to all batch export queries. + """ + + def __init__( + self, + flush_callable: FlushCallable, + max_bytes: int, + file_kwargs: collections.abc.Mapping[str, typing.Any], + ): + self.flush_callable = flush_callable + self.max_bytes = max_bytes + self.file_kwargs = file_kwargs + + self._batch_export_file: BatchExportTemporaryFile | None = None + self.reset_writer_tracking() + + def reset_writer_tracking(self): + self.last_inserted_at: dt.datetime | None = None + self.records_total = 0 + self.records_since_last_flush = 0 + self.bytes_total = 0 + self.bytes_since_last_flush = 0 + + @contextlib.asynccontextmanager + async def open_temporary_file(self): + """Explicitly open the temporary file this writer is writing to. 
+ + The underlying `BatchExportTemporaryFile` is only accessible within this context manager. This helps + us separate the lifetime of the underlying temporary file from the writer: The writer may still be + accessed even after the temporary file is closed, while on the other hand we ensure the file and all + its data is flushed and not leaked outside the context. Any relevant tracking information + """ + self.reset_writer_tracking() + + with BatchExportTemporaryFile(**self.file_kwargs) as temp_file: + self._batch_export_file = temp_file + + try: + yield + finally: + self.track_bytes_written(temp_file) + + if self.last_inserted_at is not None and self.bytes_since_last_flush > 0: + # `bytes_since_last_flush` should be 0 unless: + # 1. The last batch wasn't flushed as it didn't reach `max_bytes`. + # 2. The last batch was flushed but there was another write after the last call to + # `write_record_batch`. For example, footer bytes. + await self.flush(self.last_inserted_at, is_last=True) + + self._batch_export_file = None + + @property + def batch_export_file(self): + if self._batch_export_file is None: + raise ValueError("Batch export file is closed. Did you forget to call 'open_temporary_file'?") + return self._batch_export_file + + @abc.abstractmethod + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write a record batch to the underlying `BatchExportTemporaryFile`. + + Subclasses must override this to provide the actual implementation according to the supported + file format. + """ + pass + + def track_records_written(self, record_batch: pa.RecordBatch) -> None: + self.records_total += record_batch.num_rows + self.records_since_last_flush += record_batch.num_rows + + def track_bytes_written(self, batch_export_file: BatchExportTemporaryFile) -> None: + self.bytes_total = batch_export_file.bytes_total + self.bytes_since_last_flush = batch_export_file.bytes_since_last_reset + + async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Issue a record batch write tracking progress and flushing if required.""" + last_inserted_at = record_batch.column("_inserted_at")[0].as_py() + + column_names = record_batch.column_names + column_names.pop(column_names.index("_inserted_at")) + + self._write_record_batch(record_batch.select(column_names)) + + self.last_inserted_at = last_inserted_at + self.track_records_written(record_batch) + self.track_bytes_written(self.batch_export_file) + + if self.bytes_since_last_flush >= self.max_bytes: + await self.flush(last_inserted_at) + + async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: + """Call the provided `flush_callable` and reset underlying file.""" + await self.flush_callable( + self.batch_export_file, + self.records_since_last_flush, + self.bytes_since_last_flush, + last_inserted_at, + is_last, + ) + self.batch_export_file.reset() + + self.records_since_last_flush = 0 + self.bytes_since_last_flush = 0 + + +class JSONLBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for JSONLines format.""" + + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + compression: None | str = None, + default: typing.Callable = str, + ): + super().__init__( + max_bytes=max_bytes, + flush_callable=flush_callable, + file_kwargs={"compression": compression}, + ) + + self.default = default + + def write(self, content: bytes) -> int: + n = self.batch_export_file.write(orjson.dumps(content, default=str) + b"\n") + return n + + def _write_record_batch(self, 
record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as JSONL.""" + for record in record_batch.to_pylist(): + self.write(record) + + +class CSVBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for CSV format.""" + + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + field_names: collections.abc.Sequence[str], + extras_action: typing.Literal["raise", "ignore"] = "ignore", + delimiter: str = ",", + quote_char: str = '"', + escape_char: str | None = "\\", + line_terminator: str = "\n", + quoting=csv.QUOTE_NONE, + compression: str | None = None, + ): + super().__init__( + max_bytes=max_bytes, + flush_callable=flush_callable, + file_kwargs={"compression": compression}, + ) + self.field_names = field_names + self.extras_action: typing.Literal["raise", "ignore"] = extras_action + self.delimiter = delimiter + self.quote_char = quote_char + self.escape_char = escape_char + self.line_terminator = line_terminator + self.quoting = quoting + + self._csv_writer: csv.DictWriter | None = None + + @property + def csv_writer(self) -> csv.DictWriter: + if self._csv_writer is None: + self._csv_writer = csv.DictWriter( + self.batch_export_file, + fieldnames=self.field_names, + extrasaction=self.extras_action, + delimiter=self.delimiter, + quotechar=self.quote_char, + escapechar=self.escape_char, + quoting=self.quoting, + lineterminator=self.line_terminator, + ) + + return self._csv_writer + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as CSV.""" + self.csv_writer.writerows(record_batch.to_pylist()) + + +class ParquetBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for Apache Parquet format.""" + + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + schema: pa.Schema, + version: str = "2.6", + compression: str | None = "snappy", + compression_level: int | None = None, + ): + super().__init__( + max_bytes=max_bytes, + flush_callable=flush_callable, + file_kwargs={"compression": None}, # ParquetWriter handles compression + ) + self.schema = schema + self.version = version + self.compression = compression + self.compression_level = compression_level + + self._parquet_writer: pq.ParquetWriter | None = None + + @property + def parquet_writer(self) -> pq.ParquetWriter: + if self._parquet_writer is None: + self._parquet_writer = pq.ParquetWriter( + self.batch_export_file, + schema=self.schema, + version=self.version, + # Compression *can* be `None`. 
+ compression=self.compression, + compression_level=self.compression_level, + ) + return self._parquet_writer + + def ensure_parquet_writer_is_closed(self) -> None: + """Ensure ParquetWriter is closed as Parquet footer bytes are written on closing.""" + if self._parquet_writer is None: + return + + self._parquet_writer.writer.close() + self._parquet_writer = None + + @contextlib.asynccontextmanager + async def open_temporary_file(self): + """Ensure underlying Parquet writer is closed before flushing and closing temporary file.""" + async with super().open_temporary_file(): + try: + yield + finally: + self.ensure_parquet_writer_is_closed() + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as Parquet.""" + self.parquet_writer.write_batch(record_batch.select(self.parquet_writer.schema.names)) diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py index 0afbfcabb71cb..756c07e442e4f 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_exports.py +++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py @@ -1,6 +1,4 @@ -import csv import datetime as dt -import io import json import operator from random import randint @@ -9,11 +7,9 @@ from django.test import override_settings from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, get_data_interval, get_rows_count, iter_records, - json_dumps_bytes, ) from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse @@ -558,181 +554,3 @@ def test_get_data_interval(interval, data_interval_end, expected): """Test get_data_interval returns the expected data interval tuple.""" result = get_data_interval(interval, data_interval_end) assert result == expected - - -@pytest.mark.parametrize( - "to_write", - [ - (b"",), - (b"", b""), - (b"12345",), - (b"12345", b"12345"), - (b"abbcccddddeeeee",), - (b"abbcccddddeeeee", b"abbcccddddeeeee"), - ], -) -def test_batch_export_temporary_file_tracks_bytes(to_write): - """Test the bytes written by BatchExportTemporaryFile match expected.""" - with BatchExportTemporaryFile() as be_file: - for content in to_write: - be_file.write(content) - - assert be_file.bytes_total == sum(len(content) for content in to_write) - assert be_file.bytes_since_last_reset == sum(len(content) for content in to_write) - - be_file.reset() - - assert be_file.bytes_total == sum(len(content) for content in to_write) - assert be_file.bytes_since_last_reset == 0 - - -TEST_RECORDS = [ - [], - [ - {"id": "record-1", "property": "value", "property_int": 1}, - {"id": "record-2", "property": "another-value", "property_int": 2}, - { - "id": "record-3", - "property": {"id": "nested-record", "property": "nested-value"}, - "property_int": 3, - }, - ], -] - - -@pytest.mark.parametrize( - "records", - TEST_RECORDS, -) -def test_batch_export_temporary_file_write_records_to_jsonl(records): - """Test JSONL records written by BatchExportTemporaryFile match expected.""" - jsonl_dump = b"\n".join(map(json_dumps_bytes, records)) - - with BatchExportTemporaryFile() as be_file: - be_file.write_records_to_jsonl(records) - - assert be_file.bytes_total == len(jsonl_dump) - assert be_file.bytes_since_last_reset == len(jsonl_dump) - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == len(records) - - be_file.seek(0) - lines = be_file.readlines() - assert len(lines) == len(records) - - for line_index, jsonl_record in 
enumerate(lines): - json_loaded = json.loads(jsonl_record) - assert json_loaded == records[line_index] - - be_file.reset() - - assert be_file.bytes_total == len(jsonl_dump) - assert be_file.bytes_since_last_reset == 0 - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == 0 - - -@pytest.mark.parametrize( - "records", - TEST_RECORDS, -) -def test_batch_export_temporary_file_write_records_to_csv(records): - """Test CSV written by BatchExportTemporaryFile match expected.""" - in_memory_file_obj = io.StringIO() - writer = csv.DictWriter( - in_memory_file_obj, - fieldnames=records[0].keys() if len(records) > 0 else [], - delimiter=",", - quotechar='"', - escapechar="\\", - lineterminator="\n", - quoting=csv.QUOTE_NONE, - ) - writer.writerows(records) - - with BatchExportTemporaryFile(mode="w+") as be_file: - be_file.write_records_to_csv(records) - - assert be_file.bytes_total == in_memory_file_obj.tell() - assert be_file.bytes_since_last_reset == in_memory_file_obj.tell() - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == len(records) - - be_file.seek(0) - reader = csv.reader( - be_file._file, - delimiter=",", - quotechar='"', - escapechar="\\", - quoting=csv.QUOTE_NONE, - ) - - rows = [row for row in reader] - assert len(rows) == len(records) - - for row_index, csv_record in enumerate(rows): - for value_index, value in enumerate(records[row_index].values()): - # Everything returned by csv.reader is a str. - # This means type information is lost when writing to CSV - # but this just a limitation of the format. - assert csv_record[value_index] == str(value) - - be_file.reset() - - assert be_file.bytes_total == in_memory_file_obj.tell() - assert be_file.bytes_since_last_reset == 0 - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == 0 - - -@pytest.mark.parametrize( - "records", - TEST_RECORDS, -) -def test_batch_export_temporary_file_write_records_to_tsv(records): - """Test TSV written by BatchExportTemporaryFile match expected.""" - in_memory_file_obj = io.StringIO() - writer = csv.DictWriter( - in_memory_file_obj, - fieldnames=records[0].keys() if len(records) > 0 else [], - delimiter="\t", - quotechar='"', - escapechar="\\", - lineterminator="\n", - quoting=csv.QUOTE_NONE, - ) - writer.writerows(records) - - with BatchExportTemporaryFile(mode="w+") as be_file: - be_file.write_records_to_tsv(records) - - assert be_file.bytes_total == in_memory_file_obj.tell() - assert be_file.bytes_since_last_reset == in_memory_file_obj.tell() - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == len(records) - - be_file.seek(0) - reader = csv.reader( - be_file._file, - delimiter="\t", - quotechar='"', - escapechar="\\", - quoting=csv.QUOTE_NONE, - ) - - rows = [row for row in reader] - assert len(rows) == len(records) - - for row_index, csv_record in enumerate(rows): - for value_index, value in enumerate(records[row_index].values()): - # Everything returned by csv.reader is a str. - # This means type information is lost when writing to CSV - # but this just a limitation of the format. 
- assert csv_record[value_index] == str(value) - - be_file.reset() - - assert be_file.bytes_total == in_memory_file_obj.tell() - assert be_file.bytes_since_last_reset == 0 - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == 0 diff --git a/posthog/temporal/tests/batch_exports/test_temporary_file.py b/posthog/temporal/tests/batch_exports/test_temporary_file.py new file mode 100644 index 0000000000000..9754fe2c6e702 --- /dev/null +++ b/posthog/temporal/tests/batch_exports/test_temporary_file.py @@ -0,0 +1,188 @@ +import csv +import io +import json + +import pytest + +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, + json_dumps_bytes, +) + + +@pytest.mark.parametrize( + "to_write", + [ + (b"",), + (b"", b""), + (b"12345",), + (b"12345", b"12345"), + (b"abbcccddddeeeee",), + (b"abbcccddddeeeee", b"abbcccddddeeeee"), + ], +) +def test_batch_export_temporary_file_tracks_bytes(to_write): + """Test the bytes written by BatchExportTemporaryFile match expected.""" + with BatchExportTemporaryFile() as be_file: + for content in to_write: + be_file.write(content) + + assert be_file.bytes_total == sum(len(content) for content in to_write) + assert be_file.bytes_since_last_reset == sum(len(content) for content in to_write) + + be_file.reset() + + assert be_file.bytes_total == sum(len(content) for content in to_write) + assert be_file.bytes_since_last_reset == 0 + + +TEST_RECORDS = [ + [], + [ + {"id": "record-1", "property": "value", "property_int": 1}, + {"id": "record-2", "property": "another-value", "property_int": 2}, + { + "id": "record-3", + "property": {"id": "nested-record", "property": "nested-value"}, + "property_int": 3, + }, + ], +] + + +@pytest.mark.parametrize( + "records", + TEST_RECORDS, +) +def test_batch_export_temporary_file_write_records_to_jsonl(records): + """Test JSONL records written by BatchExportTemporaryFile match expected.""" + jsonl_dump = b"\n".join(map(json_dumps_bytes, records)) + + with BatchExportTemporaryFile() as be_file: + be_file.write_records_to_jsonl(records) + + assert be_file.bytes_total == len(jsonl_dump) + assert be_file.bytes_since_last_reset == len(jsonl_dump) + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == len(records) + + be_file.seek(0) + lines = be_file.readlines() + assert len(lines) == len(records) + + for line_index, jsonl_record in enumerate(lines): + json_loaded = json.loads(jsonl_record) + assert json_loaded == records[line_index] + + be_file.reset() + + assert be_file.bytes_total == len(jsonl_dump) + assert be_file.bytes_since_last_reset == 0 + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == 0 + + +@pytest.mark.parametrize( + "records", + TEST_RECORDS, +) +def test_batch_export_temporary_file_write_records_to_csv(records): + """Test CSV written by BatchExportTemporaryFile match expected.""" + in_memory_file_obj = io.StringIO() + writer = csv.DictWriter( + in_memory_file_obj, + fieldnames=records[0].keys() if len(records) > 0 else [], + delimiter=",", + quotechar='"', + escapechar="\\", + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + writer.writerows(records) + + with BatchExportTemporaryFile(mode="w+") as be_file: + be_file.write_records_to_csv(records) + + assert be_file.bytes_total == in_memory_file_obj.tell() + assert be_file.bytes_since_last_reset == in_memory_file_obj.tell() + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == 
len(records) + + be_file.seek(0) + reader = csv.reader( + be_file._file, + delimiter=",", + quotechar='"', + escapechar="\\", + quoting=csv.QUOTE_NONE, + ) + + rows = [row for row in reader] + assert len(rows) == len(records) + + for row_index, csv_record in enumerate(rows): + for value_index, value in enumerate(records[row_index].values()): + # Everything returned by csv.reader is a str. + # This means type information is lost when writing to CSV + # but this just a limitation of the format. + assert csv_record[value_index] == str(value) + + be_file.reset() + + assert be_file.bytes_total == in_memory_file_obj.tell() + assert be_file.bytes_since_last_reset == 0 + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == 0 + + +@pytest.mark.parametrize( + "records", + TEST_RECORDS, +) +def test_batch_export_temporary_file_write_records_to_tsv(records): + """Test TSV written by BatchExportTemporaryFile match expected.""" + in_memory_file_obj = io.StringIO() + writer = csv.DictWriter( + in_memory_file_obj, + fieldnames=records[0].keys() if len(records) > 0 else [], + delimiter="\t", + quotechar='"', + escapechar="\\", + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + writer.writerows(records) + + with BatchExportTemporaryFile(mode="w+") as be_file: + be_file.write_records_to_tsv(records) + + assert be_file.bytes_total == in_memory_file_obj.tell() + assert be_file.bytes_since_last_reset == in_memory_file_obj.tell() + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == len(records) + + be_file.seek(0) + reader = csv.reader( + be_file._file, + delimiter="\t", + quotechar='"', + escapechar="\\", + quoting=csv.QUOTE_NONE, + ) + + rows = [row for row in reader] + assert len(rows) == len(records) + + for row_index, csv_record in enumerate(rows): + for value_index, value in enumerate(records[row_index].values()): + # Everything returned by csv.reader is a str. + # This means type information is lost when writing to CSV + # but this just a limitation of the format. + assert csv_record[value_index] == str(value) + + be_file.reset() + + assert be_file.bytes_total == in_memory_file_obj.tell() + assert be_file.bytes_since_last_reset == 0 + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == 0 From 340e0370af4a0b1139ea0ba28d5d75b1531aa09b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Thu, 21 Mar 2024 12:18:19 +0100 Subject: [PATCH 08/14] test: Add writer classes tests and docstrings --- .../temporal/batch_exports/temporary_file.py | 91 +++++--- .../batch_exports/test_temporary_file.py | 201 ++++++++++++++++++ 2 files changed, 264 insertions(+), 28 deletions(-) diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index 40dfc72bc64c9..23f66c5c2292f 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -78,6 +78,17 @@ def brotli_compressor(self): self._brotli_compressor = brotli.Compressor() return self._brotli_compressor + def finish_brotli_compressor(self): + """Flush remaining brotli bytes.""" + # TODO: Move compression out of `BatchExportTemporaryFile` to a standard class for all writers. 
+ if self.compression != "brotli": + raise ValueError(f"Compression is '{self.compression}', not 'brotli'") + + result = self._file.write(self.brotli_compressor.finish()) + self.bytes_total += result + self.bytes_since_last_reset += result + self._brotli_compressor = None + def compress(self, content: bytes | str) -> bytes: if isinstance(content, str): encoded = content.encode("utf-8") @@ -188,14 +199,6 @@ def write_records_to_tsv( def rewind(self): """Rewind the file before reading it.""" - if self.compression == "brotli": - result = self._file.write(self.brotli_compressor.finish()) - - self.bytes_total += result - self.bytes_since_last_reset += result - - self._brotli_compressor = None - self._file.seek(0) def reset(self): @@ -231,36 +234,38 @@ class BatchExportWriter(abc.ABC): Actual writing calls are passed to the underlying `batch_export_file`. Attributes: - batch_export_file: The temporary file we are writing to. - bytes_flush_threshold: Flush the temporary file with the provided `flush_callable` + _batch_export_file: The temporary file we are writing to. + max_bytes: Flush the temporary file with the provided `flush_callable` upon reaching or surpassing this threshold. Keep in mind we write on a RecordBatch per RecordBatch basis, which means the threshold will be surpassed by at most the size of a RecordBatch before a flush occurs. - flush_callable: A callback to flush the temporary file when `bytes_flush_treshold` is reached. + flush_callable: A callback to flush the temporary file when `max_bytes` is reached. The temporary file will be reset after calling `flush_callable`. + file_kwargs: Optional keyword arguments passed when initializing `_batch_export_file`. + last_inserted_at: Latest `_inserted_at` written. This attribute leaks some implementation + details, as we assume `_inserted_at` is present, as it's added to all + batch export queries. records_total: The total number of records (not RecordBatches!) written. records_since_last_flush: The number of records written since last flush. - last_inserted_at: Latest `_inserted_at` written. This attribute leaks some implementation - details, as we are making two assumptions about the RecordBatches being written: - * We assume RecordBatches are sorted on `_inserted_at`, which currently happens with - an `ORDER BY` clause. - * We assume `_inserted_at` is present, as it's added to all batch export queries. + bytes_total: The total number of bytes written. + bytes_since_last_flush: The number of bytes written since last flush. """ def __init__( self, flush_callable: FlushCallable, max_bytes: int, - file_kwargs: collections.abc.Mapping[str, typing.Any], + file_kwargs: collections.abc.Mapping[str, typing.Any] | None = None, ): self.flush_callable = flush_callable self.max_bytes = max_bytes - self.file_kwargs = file_kwargs + self.file_kwargs: collections.abc.Mapping[str, typing.Any] = file_kwargs or {} self._batch_export_file: BatchExportTemporaryFile | None = None self.reset_writer_tracking() def reset_writer_tracking(self): + """Reset this writer's tracking state.""" self.last_inserted_at: dt.datetime | None = None self.records_total = 0 self.records_since_last_flush = 0 @@ -274,7 +279,8 @@ async def open_temporary_file(self): """Explicitly open the temporary file this writer is writing to.
This helps us separate the lifetime of the underlying temporary file from the writer: The writer may still be accessed even after the temporary file is closed, while on the other hand we ensure the file and all - its data is flushed and not leaked outside the context. Any relevant tracking information + its data is flushed and not leaked outside the context. Any relevant tracking information is copied + to the writer. """ self.reset_writer_tracking() @@ -297,6 +303,11 @@ async def open_temporary_file(self): @property def batch_export_file(self): + """Property for underlying temporary file. + + Raises: + ValueError: if attempting to access the temporary file before it has been opened. + """ if self._batch_export_file is None: raise ValueError("Batch export file is closed. Did you forget to call 'open_temporary_file'?") return self._batch_export_file @@ -311,16 +322,19 @@ def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: pass def track_records_written(self, record_batch: pa.RecordBatch) -> None: + """Update this writer's state with the number of records in `record_batch`.""" self.records_total += record_batch.num_rows self.records_since_last_flush += record_batch.num_rows def track_bytes_written(self, batch_export_file: BatchExportTemporaryFile) -> None: + """Update this writer's state with the bytes in `batch_export_file`.""" self.bytes_total = batch_export_file.bytes_total self.bytes_since_last_flush = batch_export_file.bytes_since_last_reset async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: """Issue a record batch write tracking progress and flushing if required.""" - last_inserted_at = record_batch.column("_inserted_at")[0].as_py() + record_batch = record_batch.sort_by("_inserted_at") + last_inserted_at = record_batch.column("_inserted_at")[-1].as_py() column_names = record_batch.column_names column_names.pop(column_names.index("_inserted_at")) @@ -336,6 +350,11 @@ async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: """Call the provided `flush_callable` and reset underlying file.""" + if is_last is True and self.batch_export_file.compression == "brotli": + self.batch_export_file.finish_brotli_compressor() + + self.batch_export_file.seek(0) + await self.flush_callable( self.batch_export_file, self.records_since_last_flush, @@ -350,7 +369,12 @@ async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> N class JSONLBatchExportWriter(BatchExportWriter): - """A `BatchExportWriter` for JSONLines format.""" + """A `BatchExportWriter` for JSONLines format. + + Attributes: + default: The default function to use to cast non-serializable Python objects to serializable objects. + By default, non-serializable objects will be cast to string via `str()`. + """ def __init__( self, @@ -368,6 +392,7 @@ def __init__( self.default = default def write(self, content: bytes) -> int: + """Write a single row of JSONL.""" n = self.batch_export_file.write(orjson.dumps(content, default=str) + b"\n") return n @@ -430,16 +455,29 @@ def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: class ParquetBatchExportWriter(BatchExportWriter): - """A `BatchExportWriter` for Apache Parquet format.""" + """A `BatchExportWriter` for Apache Parquet format. + + We utilize and wrap a `pyarrow.parquet.ParquetWriter` to do the actual writing. 
We default to their + defaults for most parameters; however this class could be extended with more attributes to pass along + to `pyarrow.parquet.ParquetWriter`. + + See the pyarrow docs for more details on what parameters can the writer be configured with: + https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html + + In contrast to other writers, instead of us handling compression we let `pyarrow.parquet.ParquetWriter` + handle it, so `BatchExportTemporaryFile` is always initialized with `compression=None`. + + Attributes: + schema: The schema used by the Parquet file. Should match the schema of written RecordBatches. + compression: Compression codec passed to underlying `pyarrow.parquet.ParquetWriter`. + """ def __init__( self, max_bytes: int, flush_callable: FlushCallable, schema: pa.Schema, - version: str = "2.6", compression: str | None = "snappy", - compression_level: int | None = None, ): super().__init__( max_bytes=max_bytes, @@ -447,9 +485,7 @@ def __init__( file_kwargs={"compression": None}, # ParquetWriter handles compression ) self.schema = schema - self.version = version self.compression = compression - self.compression_level = compression_level self._parquet_writer: pq.ParquetWriter | None = None @@ -459,10 +495,8 @@ def parquet_writer(self) -> pq.ParquetWriter: self._parquet_writer = pq.ParquetWriter( self.batch_export_file, schema=self.schema, - version=self.version, # Compression *can* be `None`. compression=self.compression, - compression_level=self.compression_level, ) return self._parquet_writer @@ -485,4 +519,5 @@ async def open_temporary_file(self): def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: """Write records to a temporary file as Parquet.""" + self.parquet_writer.write_batch(record_batch.select(self.parquet_writer.schema.names)) diff --git a/posthog/temporal/tests/batch_exports/test_temporary_file.py b/posthog/temporal/tests/batch_exports/test_temporary_file.py index 9754fe2c6e702..4fd7e69c0c12f 100644 --- a/posthog/temporal/tests/batch_exports/test_temporary_file.py +++ b/posthog/temporal/tests/batch_exports/test_temporary_file.py @@ -1,11 +1,17 @@ import csv +import datetime as dt import io import json +import pyarrow as pa +import pyarrow.parquet as pq import pytest from posthog.temporal.batch_exports.temporary_file import ( BatchExportTemporaryFile, + CSVBatchExportWriter, + JSONLBatchExportWriter, + ParquetBatchExportWriter, json_dumps_bytes, ) @@ -186,3 +192,198 @@ def test_batch_export_temporary_file_write_records_to_tsv(records): assert be_file.bytes_since_last_reset == 0 assert be_file.records_total == len(records) assert be_file.records_since_last_reset == 0 + + +TEST_RECORD_BATCHES = [ + pa.RecordBatch.from_pydict( + { + "event": pa.array(["test-event-0", "test-event-1", "test-event-2"]), + "properties": pa.array(['{"prop_0": 1, "prop_1": 2}', "{}", "null"]), + "_inserted_at": pa.array([0, 1, 2]), + } + ) +] + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_jsonl_writer_writes_record_batches(record_batch): + """Test record batches are written as valid JSONL.""" + in_memory_file_obj = io.BytesIO() + inserted_ats_seen = [] + + async def store_in_memory_on_flush( + batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last + ): + in_memory_file_obj.write(batch_export_file.read()) + inserted_ats_seen.append(last_inserted_at) + + writer = JSONLBatchExportWriter(max_bytes=1, 
flush_callable=store_in_memory_on_flush) + + record_batch = record_batch.sort_by("_inserted_at") + async with writer.open_temporary_file(): + await writer.write_record_batch(record_batch) + + lines = in_memory_file_obj.readlines() + for index, line in enumerate(lines): + written_jsonl = json.loads(line) + + single_record_batch = record_batch.slice(offset=index, length=1) + expected_jsonl = single_record_batch.to_pylist()[0] + + assert "_inserted_at" not in written_jsonl + assert written_jsonl == expected_jsonl + + assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()] + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_csv_writer_writes_record_batches(record_batch): + """Test record batches are written as valid CSV.""" + in_memory_file_obj = io.StringIO() + inserted_ats_seen = [] + + async def store_in_memory_on_flush( + batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last + ): + in_memory_file_obj.write(batch_export_file.read().decode("utf-8")) + inserted_ats_seen.append(last_inserted_at) + + schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] + writer = CSVBatchExportWriter(max_bytes=1, field_names=schema_columns, flush_callable=store_in_memory_on_flush) + + record_batch = record_batch.sort_by("_inserted_at") + async with writer.open_temporary_file(): + await writer.write_record_batch(record_batch) + + reader = csv.reader( + in_memory_file_obj, + delimiter=",", + quotechar='"', + escapechar="\\", + quoting=csv.QUOTE_NONE, + ) + for index, written_csv_row in enumerate(reader): + single_record_batch = record_batch.slice(offset=index, length=1) + expected_csv = single_record_batch.to_pylist()[0] + + assert "_inserted_at" not in written_csv_row + assert written_csv_row == expected_csv + + assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()] + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_parquet_writer_writes_record_batches(record_batch): + """Test record batches are written as valid Parquet.""" + in_memory_file_obj = io.BytesIO() + inserted_ats_seen = [] + + async def store_in_memory_on_flush( + batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last + ): + in_memory_file_obj.write(batch_export_file.read()) + inserted_ats_seen.append(last_inserted_at) + + schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] + + writer = ParquetBatchExportWriter( + max_bytes=1, + flush_callable=store_in_memory_on_flush, + schema=record_batch.select(schema_columns).schema, + ) + + record_batch = record_batch.sort_by("_inserted_at") + async with writer.open_temporary_file(): + await writer.write_record_batch(record_batch) + + written_parquet = pq.read_table(in_memory_file_obj) + + for index, written_row_as_dict in enumerate(written_parquet.to_pylist()): + single_record_batch = record_batch.slice(offset=index, length=1) + expected_row_as_dict = single_record_batch.select(schema_columns).to_pylist()[0] + + assert "_inserted_at" not in written_row_as_dict + assert written_row_as_dict == expected_row_as_dict + + # NOTE: Parquet gets flushed twice due to the extra flush at the end for footer bytes, so our mock function + # will see this value twice. 
+ assert inserted_ats_seen == [ + record_batch.column("_inserted_at")[-1].as_py(), + record_batch.column("_inserted_at")[-1].as_py(), + ] + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_writing_out_of_scope_of_temporary_file_raises(record_batch): + """Test attempting a write out of temporary file scope raises a `ValueError`.""" + + async def do_nothing(*args, **kwargs): + pass + + schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] + writer = ParquetBatchExportWriter( + max_bytes=10, + flush_callable=do_nothing, + schema=record_batch.select(schema_columns).schema, + ) + + async with writer.open_temporary_file(): + pass + + with pytest.raises(ValueError, match="Batch export file is closed"): + await writer.write_record_batch(record_batch) + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_flushing_parquet_writer_resets_underlying_file(record_batch): + """Test flushing a writer resets underlying file.""" + flush_counter = 0 + + async def track_flushes(*args, **kwargs): + nonlocal flush_counter + flush_counter += 1 + + schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] + writer = ParquetBatchExportWriter( + max_bytes=10000000, + flush_callable=track_flushes, + schema=record_batch.select(schema_columns).schema, + ) + + async with writer.open_temporary_file(): + await writer.write_record_batch(record_batch) + + assert writer.batch_export_file.tell() > 0 + assert writer.bytes_since_last_flush > 0 + assert writer.bytes_since_last_flush == writer.batch_export_file.bytes_since_last_reset + assert writer.records_since_last_flush == record_batch.num_rows + + await writer.flush(dt.datetime.now()) + + assert flush_counter == 1 + assert writer.batch_export_file.tell() == 0 + assert writer.bytes_since_last_flush == 0 + assert writer.bytes_since_last_flush == writer.batch_export_file.bytes_since_last_reset + assert writer.records_since_last_flush == 0 + + assert flush_counter == 2 From c1ffe130049b77a06e3147e7bb30155b1bee2041 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Fri, 22 Mar 2024 10:34:57 +0100 Subject: [PATCH 09/14] feat: Add new type aliases and docstrings for FlushCallable --- .../temporal/batch_exports/temporary_file.py | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index 23f66c5c2292f..6a59438e64029 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -213,8 +213,13 @@ def reset(self): self.records_since_last_reset = 0 +LastInsertedAt = dt.datetime +IsLast = bool +RecordsSinceLastFlush = int +BytesSinceLastFlush = int FlushCallable = collections.abc.Callable[ - [BatchExportTemporaryFile, int, int, dt.datetime, bool], collections.abc.Awaitable[None] + [BatchExportTemporaryFile, RecordsSinceLastFlush, BytesSinceLastFlush, LastInsertedAt, IsLast], + collections.abc.Awaitable[None], ] @@ -240,7 +245,11 @@ class BatchExportWriter(abc.ABC): per RecordBatch basis, which means the threshold will be surpassed by at most the size of a RecordBatch before a flush occurs. flush_callable: A callback to flush the temporary file when `max_bytes` is reached. - The temporary file will be reset after calling `flush_callable`. 
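To make the new `FlushCallable` alias introduced above concrete, a minimal conforming callback could look like the following sketch (illustrative only; the function name and the logging it does are not part of the patch). It takes the positional arguments documented in the docstring that follows and is awaited by the writer on every flush.

# A coroutine matching FlushCallable: temporary file, records since last
# flush, bytes since last flush, latest _inserted_at, and the is_last flag.
import datetime as dt

from posthog.temporal.batch_exports.temporary_file import BatchExportTemporaryFile


async def log_flush(
    batch_export_file: BatchExportTemporaryFile,
    records_since_last_flush: int,
    bytes_since_last_flush: int,
    last_inserted_at: dt.datetime,
    is_last: bool,
) -> None:
    print(
        f"Flushing {records_since_last_flush} records "
        f"({bytes_since_last_flush} bytes), "
        f"last _inserted_at={last_inserted_at}, is_last={is_last}"
    )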
+ The temporary file will be reset after calling `flush_callable`. When calling + `flush_callable` the following positional arguments will be passed: The temporary file + that must be flushed, the number of records since the last flush, the number of bytes + since the last flush, the latest recorded `_inserted_at`, and a `bool` indicating if + this is the last flush (when exiting the context manager). file_kwargs: Optional keyword arguments passed when initializing `_batch_export_file`. last_inserted_at: Latest `_inserted_at` written. This attribute leaks some implementation details, as we are assuming assume `_inserted_at` is present, as it's added to all @@ -349,7 +358,10 @@ async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: await self.flush(last_inserted_at) async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: - """Call the provided `flush_callable` and reset underlying file.""" + """Call the provided `flush_callable` and reset underlying file. + + The underlying batch export temporary file will be reset after calling `flush_callable`. + """ if is_last is True and self.batch_export_file.compression == "brotli": self.batch_export_file.finish_brotli_compressor() From 626169dc44546ca1009247f122ecd19914aee6a4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Fri, 22 Mar 2024 10:43:46 +0100 Subject: [PATCH 10/14] refactor: Get rid of ensure close method --- posthog/temporal/batch_exports/temporary_file.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index 6a59438e64029..abcd7b60ca0d4 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -512,14 +512,6 @@ def parquet_writer(self) -> pq.ParquetWriter: ) return self._parquet_writer - def ensure_parquet_writer_is_closed(self) -> None: - """Ensure ParquetWriter is closed as Parquet footer bytes are written on closing.""" - if self._parquet_writer is None: - return - - self._parquet_writer.writer.close() - self._parquet_writer = None - @contextlib.asynccontextmanager async def open_temporary_file(self): """Ensure underlying Parquet writer is closed before flushing and closing temporary file.""" @@ -527,7 +519,9 @@ async def open_temporary_file(self): try: yield finally: - self.ensure_parquet_writer_is_closed() + if self._parquet_writer is not None: + self._parquet_writer.writer.close() + self._parquet_writer = None def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: """Write records to a temporary file as Parquet.""" From 902f21d16a31e04f2d85589e4bfac2edbf5e9223 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Fri, 22 Mar 2024 11:27:28 +0100 Subject: [PATCH 11/14] fix: Use proper 'none' compression --- posthog/temporal/batch_exports/temporary_file.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py index abcd7b60ca0d4..f955f45553727 100644 --- a/posthog/temporal/batch_exports/temporary_file.py +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -507,8 +507,7 @@ def parquet_writer(self) -> pq.ParquetWriter: self._parquet_writer = pq.ParquetWriter( self.batch_export_file, schema=self.schema, - # Compression *can* be `None`. 
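For context on the one-line change that follows: the writer previously forwarded its `compression` attribute to `pyarrow.parquet.ParquetWriter` as-is, so a bare `None` could be passed where pyarrow expects a codec name such as `"snappy"` or the string `"none"` for uncompressed output. A standalone sketch of the accepted value (illustrative only, not part of the patch):

# pyarrow accepts the string "none" as a ParquetWriter compression codec,
# which is what the patched code below maps a bare None to.
import io

import pyarrow as pa
import pyarrow.parquet as pq

schema = pa.schema([("event", pa.string())])
with pq.ParquetWriter(io.BytesIO(), schema=schema, compression="none") as writer:
    writer.write_batch(pa.RecordBatch.from_pydict({"event": ["a"]}, schema=schema))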
- compression=self.compression, + compression="none" if self.compression is None else self.compression, ) return self._parquet_writer From 73469ac23b5b7abb5da466e96dc4dc7f909a0582 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Fri, 22 Mar 2024 11:27:54 +0100 Subject: [PATCH 12/14] refactor: Cover all possible file formats with FILE_FORMAT_EXTENSIONS.keys() --- .../tests/batch_exports/test_s3_batch_export_workflow.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index 8c700fb191f2b..dcfeb542d6f63 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -29,6 +29,7 @@ update_export_run_status, ) from posthog.temporal.batch_exports.s3_batch_export import ( + FILE_FORMAT_EXTENSIONS, HeartbeatDetails, S3BatchExportInputs, S3BatchExportWorkflow, @@ -311,7 +312,7 @@ async def assert_clickhouse_records_in_s3( @pytest.mark.parametrize("compression", [None, "gzip", "brotli"], indirect=True) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) -@pytest.mark.parametrize("file_format", ["JSONLines", "Parquet"]) +@pytest.mark.parametrize("file_format", FILE_FORMAT_EXTENSIONS.keys()) async def test_insert_into_s3_activity_puts_data_into_s3( clickhouse_client, bucket_name, @@ -474,7 +475,7 @@ async def s3_batch_export( @pytest.mark.parametrize("compression", [None, "gzip", "brotli"], indirect=True) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) -@pytest.mark.parametrize("file_format", ["JSONLines", "Parquet"], indirect=True) +@pytest.mark.parametrize("file_format", FILE_FORMAT_EXTENSIONS.keys(), indirect=True) async def test_s3_export_workflow_with_minio_bucket( clickhouse_client, minio_client, @@ -604,7 +605,7 @@ async def s3_client(bucket_name, s3_key_prefix): @pytest.mark.parametrize("encryption", [None, "AES256", "aws:kms"], indirect=True) @pytest.mark.parametrize("bucket_name", [os.getenv("S3_TEST_BUCKET")], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) -@pytest.mark.parametrize("file_format", ["JSONLines", "Parquet"]) +@pytest.mark.parametrize("file_format", FILE_FORMAT_EXTENSIONS.keys(), indirect=True) async def test_s3_export_workflow_with_s3_bucket( s3_client, clickhouse_client, From d015fc81eae19018e7a8132fa5cdfbd0bc5b1336 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Fri, 22 Mar 2024 11:28:14 +0100 Subject: [PATCH 13/14] test: Also check if bucket name is set to use S3 --- .../tests/batch_exports/test_s3_batch_export_workflow.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index dcfeb542d6f63..e6583d049e2a8 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -162,7 +162,14 @@ async def read_parquet_from_s3(bucket_name: str, key: str, json_columns) -> list ) else: - s3 = fs.S3FileSystem() + if os.getenv("S3_TEST_BUCKET") is not None: + s3 = 
fs.S3FileSystem()
+        else:
+            s3 = fs.S3FileSystem(
+                access_key="object_storage_root_user",
+                secret_key="object_storage_root_password",
+                endpoint_override=settings.OBJECT_STORAGE_ENDPOINT,
+            )
 
     table = pq.read_table(f"{bucket_name}/{key}", filesystem=s3)
 

From 892140383fa37cceb0d5e78c213af8cbb100166c Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?=
Date: Fri, 22 Mar 2024 11:33:06 +0100
Subject: [PATCH 14/14] feat: Typing and docstring for get_batch_export_writer

---
 posthog/temporal/batch_exports/s3_batch_export.py | 7 ++++++-
 1 file changed, 6 insertions(+), 1 deletion(-)

diff --git a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py
index 70fa3e3e71490..e83fe3f12915d 100644
--- a/posthog/temporal/batch_exports/s3_batch_export.py
+++ b/posthog/temporal/batch_exports/s3_batch_export.py
@@ -553,8 +553,13 @@ async def flush_to_s3(
 
 
 def get_batch_export_writer(
-    inputs, flush_callable: FlushCallable, max_bytes: int, schema: pa.Schema | None = None
+    inputs: S3InsertInputs, flush_callable: FlushCallable, max_bytes: int, schema: pa.Schema | None = None
 ) -> BatchExportWriter:
+    """Return the `BatchExportWriter` corresponding to the configured `file_format`.
+
+    Raises:
+        UnsupportedFileFormatError: If no writer exists for the given `file_format`.
+    """
     writer: BatchExportWriter
 
     if inputs.file_format == "Parquet":