From 7c0258cc0e5a346736fdf4a67801810a49b8d8d2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Tom=C3=A1s=20Far=C3=ADas=20Santana?= Date: Fri, 22 Mar 2024 14:20:13 +0100 Subject: [PATCH] refactor: Support for multiple file formats in S3 batch exports (#20979) * refactor: Support for multiple file formats in batch exports * refactor: Prefer composition over inheritance * refactor: More clearly separate writer from temporary file We now should be more explicit about what is the context in which the batch export temporary file is alive. The writer outlives this context, so it can be used by callers to, for example, check how many records were written. * test: More parquet testing * Update query snapshots * fix: Typing * refactor: Move temporary file to new module * test: Add writer classes tests and docstrings * feat: Add new type aliases and docstrings for FlushCallable * refactor: Get rid of ensure close method * fix: Use proper 'none' compression * refactor: Cover all possible file formats with FILE_FORMAT_EXTENSIONS.keys() * test: Also check if bucket name is set to use S3 * feat: Typing and docstring for get_batch_export_writer --------- Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com> --- posthog/batch_exports/service.py | 1 + .../test/__snapshots__/test_in_cohort.ambr | 8 +- .../temporal/batch_exports/batch_exports.py | 201 ------- .../batch_exports/bigquery_batch_export.py | 4 +- .../batch_exports/http_batch_export.py | 6 +- .../batch_exports/postgres_batch_export.py | 4 +- .../temporal/batch_exports/s3_batch_export.py | 203 +++++-- .../batch_exports/snowflake_batch_export.py | 4 +- .../temporal/batch_exports/temporary_file.py | 528 ++++++++++++++++++ .../tests/batch_exports/test_batch_exports.py | 182 ------ .../test_s3_batch_export_workflow.py | 159 +++++- .../batch_exports/test_temporary_file.py | 389 +++++++++++++ 12 files changed, 1238 insertions(+), 451 deletions(-) create mode 100644 
posthog/temporal/batch_exports/temporary_file.py create mode 100644 posthog/temporal/tests/batch_exports/test_temporary_file.py diff --git a/posthog/batch_exports/service.py b/posthog/batch_exports/service.py index b00f0f4c98c69..d51dfdb2fbc3c 100644 --- a/posthog/batch_exports/service.py +++ b/posthog/batch_exports/service.py @@ -90,6 +90,7 @@ class S3BatchExportInputs: kms_key_id: str | None = None batch_export_schema: BatchExportSchema | None = None endpoint_url: str | None = None + file_format: str = "JSONLines" @dataclass diff --git a/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr b/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr index 9ff7f8ee0ab49..e0f5ea847110d 100644 --- a/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr +++ b/posthog/hogql/transforms/test/__snapshots__/test_in_cohort.ambr @@ -31,7 +31,7 @@ FROM events LEFT JOIN ( SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id FROM person_static_cohort - WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [12]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) + WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [11]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 @@ -42,7 +42,7 @@ FROM events LEFT JOIN ( SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id FROM static_cohort_people - WHERE in(cohort_id, [12])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) + WHERE in(cohort_id, [11])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) WHERE and(1, equals(__in_cohort.matched, 1)) LIMIT 100 ''' @@ -55,7 +55,7 @@ FROM events LEFT JOIN ( SELECT 
person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id FROM person_static_cohort - WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [13]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) + WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [12]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id) WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0)) LIMIT 100 SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1 @@ -66,7 +66,7 @@ FROM events LEFT JOIN ( SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id FROM static_cohort_people - WHERE in(cohort_id, [13])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) + WHERE in(cohort_id, [12])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id) WHERE and(1, equals(__in_cohort.matched, 1)) LIMIT 100 ''' diff --git a/posthog/temporal/batch_exports/batch_exports.py b/posthog/temporal/batch_exports/batch_exports.py index c40950c654426..88cf9e32f274f 100644 --- a/posthog/temporal/batch_exports/batch_exports.py +++ b/posthog/temporal/batch_exports/batch_exports.py @@ -1,15 +1,10 @@ import collections.abc -import csv import dataclasses import datetime as dt -import gzip -import tempfile import typing import uuid from string import Template -import brotli -import orjson import pyarrow as pa from asgiref.sync import sync_to_async from django.conf import settings @@ -286,202 +281,6 @@ def get_data_interval(interval: str, data_interval_end: str | None) -> tuple[dt. return (data_interval_start_dt, data_interval_end_dt) -def json_dumps_bytes(d) -> bytes: - return orjson.dumps(d, default=str) - - -class BatchExportTemporaryFile: - """A TemporaryFile used to as an intermediate step while exporting data. 
- - This class does not implement the file-like interface but rather passes any calls - to the underlying tempfile.NamedTemporaryFile. We do override 'write' methods - to allow tracking bytes and records. - """ - - def __init__( - self, - mode: str = "w+b", - buffering=-1, - compression: str | None = None, - encoding: str | None = None, - newline: str | None = None, - suffix: str | None = None, - prefix: str | None = None, - dir: str | None = None, - *, - errors: str | None = None, - ): - self._file = tempfile.NamedTemporaryFile( - mode=mode, - encoding=encoding, - newline=newline, - buffering=buffering, - suffix=suffix, - prefix=prefix, - dir=dir, - errors=errors, - ) - self.compression = compression - self.bytes_total = 0 - self.records_total = 0 - self.bytes_since_last_reset = 0 - self.records_since_last_reset = 0 - self._brotli_compressor = None - - def __getattr__(self, name): - """Pass get attr to underlying tempfile.NamedTemporaryFile.""" - return self._file.__getattr__(name) - - def __enter__(self): - """Context-manager protocol enter method.""" - self._file.__enter__() - return self - - def __exit__(self, exc, value, tb): - """Context-manager protocol exit method.""" - return self._file.__exit__(exc, value, tb) - - def __iter__(self): - yield from self._file - - @property - def brotli_compressor(self): - if self._brotli_compressor is None: - self._brotli_compressor = brotli.Compressor() - return self._brotli_compressor - - def compress(self, content: bytes | str) -> bytes: - if isinstance(content, str): - encoded = content.encode("utf-8") - else: - encoded = content - - match self.compression: - case "gzip": - return gzip.compress(encoded) - case "brotli": - self.brotli_compressor.process(encoded) - return self.brotli_compressor.flush() - case None: - return encoded - case _: - raise ValueError(f"Unsupported compression: '{self.compression}'") - - def write(self, content: bytes | str): - """Write bytes to underlying file keeping track of how many bytes 
were written.""" - compressed_content = self.compress(content) - - if "b" in self.mode: - result = self._file.write(compressed_content) - else: - result = self._file.write(compressed_content.decode("utf-8")) - - self.bytes_total += result - self.bytes_since_last_reset += result - - return result - - def write_record_as_bytes(self, record: bytes): - result = self.write(record) - - self.records_total += 1 - self.records_since_last_reset += 1 - - return result - - def write_records_to_jsonl(self, records): - """Write records to a temporary file as JSONL.""" - if len(records) == 1: - jsonl_dump = orjson.dumps(records[0], option=orjson.OPT_APPEND_NEWLINE, default=str) - else: - jsonl_dump = b"\n".join(map(json_dumps_bytes, records)) - - result = self.write(jsonl_dump) - - self.records_total += len(records) - self.records_since_last_reset += len(records) - - return result - - def write_records_to_csv( - self, - records, - fieldnames: None | collections.abc.Sequence[str] = None, - extrasaction: typing.Literal["raise", "ignore"] = "ignore", - delimiter: str = ",", - quotechar: str = '"', - escapechar: str | None = "\\", - lineterminator: str = "\n", - quoting=csv.QUOTE_NONE, - ): - """Write records to a temporary file as CSV.""" - if len(records) == 0: - return - - if fieldnames is None: - fieldnames = list(records[0].keys()) - - writer = csv.DictWriter( - self, - fieldnames=fieldnames, - extrasaction=extrasaction, - delimiter=delimiter, - quotechar=quotechar, - escapechar=escapechar, - quoting=quoting, - lineterminator=lineterminator, - ) - writer.writerows(records) - - self.records_total += len(records) - self.records_since_last_reset += len(records) - - def write_records_to_tsv( - self, - records, - fieldnames: None | list[str] = None, - extrasaction: typing.Literal["raise", "ignore"] = "ignore", - quotechar: str = '"', - escapechar: str | None = "\\", - lineterminator: str = "\n", - quoting=csv.QUOTE_NONE, - ): - """Write records to a temporary file as TSV.""" - return 
self.write_records_to_csv( - records, - fieldnames=fieldnames, - extrasaction=extrasaction, - delimiter="\t", - quotechar=quotechar, - escapechar=escapechar, - quoting=quoting, - lineterminator=lineterminator, - ) - - def rewind(self): - """Rewind the file before reading it.""" - if self.compression == "brotli": - result = self._file.write(self.brotli_compressor.finish()) - - self.bytes_total += result - self.bytes_since_last_reset += result - - self._brotli_compressor = None - - self._file.seek(0) - - def reset(self): - """Reset underlying file by truncating it. - - Also resets the tracker attributes for bytes and records since last reset. - """ - self._file.seek(0) - self._file.truncate() - - self.bytes_since_last_reset = 0 - self.records_since_last_reset = 0 - - @dataclasses.dataclass class CreateBatchExportRunInputs: """Inputs to the create_export_run activity. diff --git a/posthog/temporal/batch_exports/bigquery_batch_export.py b/posthog/temporal/batch_exports/bigquery_batch_export.py index a0469de79bb9e..b754a7add16b4 100644 --- a/posthog/temporal/batch_exports/bigquery_batch_export.py +++ b/posthog/temporal/batch_exports/bigquery_batch_export.py @@ -15,7 +15,6 @@ from posthog.batch_exports.service import BatchExportField, BatchExportSchema, BigQueryBatchExportInputs from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -29,6 +28,9 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, +) from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger diff --git a/posthog/temporal/batch_exports/http_batch_export.py 
b/posthog/temporal/batch_exports/http_batch_export.py index 8aca65c80ff38..2866d50c99876 100644 --- a/posthog/temporal/batch_exports/http_batch_export.py +++ b/posthog/temporal/batch_exports/http_batch_export.py @@ -13,7 +13,6 @@ from posthog.models import BatchExportRun from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -21,12 +20,15 @@ get_data_interval, get_rows_count, iter_records, - json_dumps_bytes, ) from posthog.temporal.batch_exports.metrics import ( get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, + json_dumps_bytes, +) from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger diff --git a/posthog/temporal/batch_exports/postgres_batch_export.py b/posthog/temporal/batch_exports/postgres_batch_export.py index 5dbfc6faa4acf..98969ee78de79 100644 --- a/posthog/temporal/batch_exports/postgres_batch_export.py +++ b/posthog/temporal/batch_exports/postgres_batch_export.py @@ -17,7 +17,6 @@ from posthog.batch_exports.service import BatchExportField, BatchExportSchema, PostgresBatchExportInputs from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -31,6 +30,9 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, +) from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger diff --git 
a/posthog/temporal/batch_exports/s3_batch_export.py b/posthog/temporal/batch_exports/s3_batch_export.py index 4d99cbeffd7c3..e83fe3f12915d 100644 --- a/posthog/temporal/batch_exports/s3_batch_export.py +++ b/posthog/temporal/batch_exports/s3_batch_export.py @@ -1,4 +1,5 @@ import asyncio +import collections.abc import contextlib import datetime as dt import io @@ -8,6 +9,8 @@ from dataclasses import dataclass import aioboto3 +import orjson +import pyarrow as pa from django.conf import settings from temporalio import activity, workflow from temporalio.common import RetryPolicy @@ -16,7 +19,6 @@ from posthog.batch_exports.service import BatchExportField, BatchExportSchema, S3BatchExportInputs from posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -30,6 +32,15 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, + BatchExportWriter, + FlushCallable, + JSONLBatchExportWriter, + ParquetBatchExportWriter, + UnsupportedFileFormatError, +) +from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger @@ -50,19 +61,31 @@ def get_allowed_template_variables(inputs) -> dict[str, str]: } +FILE_FORMAT_EXTENSIONS = { + "Parquet": "parquet", + "JSONLines": "jsonl", +} + +COMPRESSION_EXTENSIONS = { + "gzip": "gz", + "snappy": "sz", + "brotli": "br", + "ztsd": "zst", + "lz4": "lz4", +} + + def get_s3_key(inputs) -> str: """Return an S3 key given S3InsertInputs.""" template_variables = get_allowed_template_variables(inputs) key_prefix = inputs.prefix.format(**template_variables) + file_extension = FILE_FORMAT_EXTENSIONS[inputs.file_format] base_file_name = 
f"{inputs.data_interval_start}-{inputs.data_interval_end}" - match inputs.compression: - case "gzip": - file_name = base_file_name + ".jsonl.gz" - case "brotli": - file_name = base_file_name + ".jsonl.br" - case _: - file_name = base_file_name + ".jsonl" + if inputs.compression is not None: + file_name = base_file_name + f".{file_extension}.{COMPRESSION_EXTENSIONS[inputs.compression]}" + else: + file_name = base_file_name + f".{file_extension}" key = posixpath.join(key_prefix, file_name) @@ -311,6 +334,8 @@ class S3InsertInputs: kms_key_id: str | None = None batch_export_schema: BatchExportSchema | None = None endpoint_url: str | None = None + # TODO: In Python 3.11, this could be a enum.StrEnum. + file_format: str = "JSONLines" async def initialize_and_resume_multipart_upload(inputs: S3InsertInputs) -> tuple[S3MultiPartUpload, str]: @@ -451,7 +476,7 @@ async def insert_into_s3_activity(inputs: S3InsertInputs) -> int: last_uploaded_part_timestamp: str | None = None - async def worker_shutdown_handler(): + async def worker_shutdown_handler() -> None: """Handle the Worker shutting down by heart-beating our latest status.""" await activity.wait_for_worker_shutdown() logger.warn( @@ -466,50 +491,147 @@ async def worker_shutdown_handler(): asyncio.create_task(worker_shutdown_handler()) - record = None - async with s3_upload as s3_upload: - with BatchExportTemporaryFile(compression=inputs.compression) as local_results_file: + + async def flush_to_s3( + local_results_file, + records_since_last_flush: int, + bytes_since_last_flush: int, + last_inserted_at: dt.datetime, + last: bool, + ): + nonlocal last_uploaded_part_timestamp + + logger.debug( + "Uploading %s part %s containing %s records with size %s bytes", + "last " if last else "", + s3_upload.part_number + 1, + records_since_last_flush, + bytes_since_last_flush, + ) + + await s3_upload.upload_part(local_results_file) + rows_exported.add(records_since_last_flush) + bytes_exported.add(bytes_since_last_flush) + + 
last_uploaded_part_timestamp = str(last_inserted_at) + activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state()) + + first_record_batch, record_iterator = peek_first_and_rewind(record_iterator) + first_record_batch = cast_record_batch_json_columns(first_record_batch) + column_names = first_record_batch.column_names + column_names.pop(column_names.index("_inserted_at")) + + schema = pa.schema( + # NOTE: For some reason, some batches set non-nullable fields as non-nullable, whereas other + # record batches have them as nullable. + # Until we figure it out, we set all fields to nullable. There are some fields we know + # are not nullable, but I'm opting for the more flexible option until we figure out why schemas differ + # between batches. + [field.with_nullable(True) for field in first_record_batch.select(column_names).schema] + ) + + writer = get_batch_export_writer( + inputs, + flush_callable=flush_to_s3, + max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, + schema=schema, + ) + + async with writer.open_temporary_file(): rows_exported = get_rows_exported_metric() bytes_exported = get_bytes_exported_metric() - async def flush_to_s3(last_uploaded_part_timestamp: str, last=False): - logger.debug( - "Uploading %s part %s containing %s records with size %s bytes", - "last " if last else "", - s3_upload.part_number + 1, - local_results_file.records_since_last_reset, - local_results_file.bytes_since_last_reset, - ) + for record_batch in record_iterator: + record_batch = cast_record_batch_json_columns(record_batch) - await s3_upload.upload_part(local_results_file) - rows_exported.add(local_results_file.records_since_last_reset) - bytes_exported.add(local_results_file.bytes_since_last_reset) + await writer.write_record_batch(record_batch) - activity.heartbeat(last_uploaded_part_timestamp, s3_upload.to_state()) + await s3_upload.complete() - for record_batch in record_iterator: - for record in record_batch.to_pylist(): - for json_column in ("properties", 
"person_properties", "set", "set_once"): - if (json_str := record.get(json_column, None)) is not None: - record[json_column] = json.loads(json_str) + return writer.records_total - inserted_at = record.pop("_inserted_at") - local_results_file.write_records_to_jsonl([record]) +def get_batch_export_writer( + inputs: S3InsertInputs, flush_callable: FlushCallable, max_bytes: int, schema: pa.Schema | None = None +) -> BatchExportWriter: + """Return the `BatchExportWriter` corresponding to configured `file_format`. - if local_results_file.tell() > settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES: - last_uploaded_part_timestamp = str(inserted_at) - await flush_to_s3(last_uploaded_part_timestamp) - local_results_file.reset() + Raises: + UnsupportedFileFormatError: If no writer exists for given `file_format`. + """ + writer: BatchExportWriter - if local_results_file.tell() > 0 and record is not None: - last_uploaded_part_timestamp = str(inserted_at) - await flush_to_s3(last_uploaded_part_timestamp, last=True) + if inputs.file_format == "Parquet": + writer = ParquetBatchExportWriter( + max_bytes=max_bytes, + flush_callable=flush_callable, + compression=inputs.compression, + schema=schema, + ) + elif inputs.file_format == "JSONLines": + writer = JSONLBatchExportWriter( + max_bytes=settings.BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES, + flush_callable=flush_callable, + compression=inputs.compression, + ) + else: + raise UnsupportedFileFormatError(inputs.file_format, "S3") - await s3_upload.complete() + return writer + + +def cast_record_batch_json_columns( + record_batch: pa.RecordBatch, + json_columns: collections.abc.Sequence = ("properties", "person_properties", "set", "set_once"), +) -> pa.RecordBatch: + """Cast json_columns in record_batch to JsonType. + + We return a new RecordBatch with any json_columns replaced by fields casted to JsonType. 
+ Casting is not copying the underlying array buffers, so memory usage does not increase when creating + the new array or the new record batch. + """ + column_names = set(record_batch.column_names) + intersection = column_names & set(json_columns) + + casted_arrays = [] + for array in record_batch.select(intersection): + if pa.types.is_string(array.type): + casted_array = array.cast(JsonType()) + casted_arrays.append(casted_array) + + remaining_column_names = list(column_names - intersection) + return pa.RecordBatch.from_arrays( + record_batch.select(remaining_column_names).columns + casted_arrays, + names=remaining_column_names + list(intersection), + ) + + +class JsonScalar(pa.ExtensionScalar): + """Represents a JSON binary string.""" + + def as_py(self) -> dict | None: + if self.value: + return orjson.loads(self.value.as_py().encode("utf-8")) + else: + return None + + +class JsonType(pa.ExtensionType): + """Type for JSON binary strings.""" + + def __init__(self): + super().__init__(pa.string(), "json") + + def __arrow_ext_serialize__(self): + return b"" + + @classmethod + def __arrow_ext_deserialize__(self, storage_type, serialized): + return JsonType() - return local_results_file.records_total + def __arrow_ext_scalar_class__(self): + return JsonScalar @workflow.defn(name="s3-export") @@ -572,6 +694,7 @@ async def run(self, inputs: S3BatchExportInputs): encryption=inputs.encryption, kms_key_id=inputs.kms_key_id, batch_export_schema=inputs.batch_export_schema, + file_format=inputs.file_format, ) await execute_batch_export_insert_activity( diff --git a/posthog/temporal/batch_exports/snowflake_batch_export.py b/posthog/temporal/batch_exports/snowflake_batch_export.py index be94eca89a799..9053f3e1006ad 100644 --- a/posthog/temporal/batch_exports/snowflake_batch_export.py +++ b/posthog/temporal/batch_exports/snowflake_batch_export.py @@ -18,7 +18,6 @@ from posthog.batch_exports.service import BatchExportField, BatchExportSchema, SnowflakeBatchExportInputs from 
posthog.temporal.batch_exports.base import PostHogWorkflow from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, CreateBatchExportRunInputs, UpdateBatchExportRunStatusInputs, create_export_run, @@ -32,6 +31,9 @@ get_bytes_exported_metric, get_rows_exported_metric, ) +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, +) from posthog.temporal.batch_exports.utils import peek_first_and_rewind from posthog.temporal.common.clickhouse import get_client from posthog.temporal.common.logger import bind_temporal_worker_logger diff --git a/posthog/temporal/batch_exports/temporary_file.py b/posthog/temporal/batch_exports/temporary_file.py new file mode 100644 index 0000000000000..f955f45553727 --- /dev/null +++ b/posthog/temporal/batch_exports/temporary_file.py @@ -0,0 +1,528 @@ +"""This module contains a temporary file to stage data in batch exports.""" +import abc +import collections.abc +import contextlib +import csv +import datetime as dt +import gzip +import tempfile +import typing + +import brotli +import orjson +import pyarrow as pa +import pyarrow.parquet as pq + + +def json_dumps_bytes(d) -> bytes: + return orjson.dumps(d, default=str) + + +class BatchExportTemporaryFile: + """A TemporaryFile used to as an intermediate step while exporting data. + + This class does not implement the file-like interface but rather passes any calls + to the underlying tempfile.NamedTemporaryFile. We do override 'write' methods + to allow tracking bytes and records. 
+ """ + + def __init__( + self, + mode: str = "w+b", + buffering=-1, + compression: str | None = None, + encoding: str | None = None, + newline: str | None = None, + suffix: str | None = None, + prefix: str | None = None, + dir: str | None = None, + *, + errors: str | None = None, + ): + self._file = tempfile.NamedTemporaryFile( + mode=mode, + encoding=encoding, + newline=newline, + buffering=buffering, + suffix=suffix, + prefix=prefix, + dir=dir, + errors=errors, + ) + self.compression = compression + self.bytes_total = 0 + self.records_total = 0 + self.bytes_since_last_reset = 0 + self.records_since_last_reset = 0 + self._brotli_compressor = None + + def __getattr__(self, name): + """Pass get attr to underlying tempfile.NamedTemporaryFile.""" + return self._file.__getattr__(name) + + def __enter__(self): + """Context-manager protocol enter method.""" + self._file.__enter__() + return self + + def __exit__(self, exc, value, tb): + """Context-manager protocol exit method.""" + return self._file.__exit__(exc, value, tb) + + def __iter__(self): + yield from self._file + + @property + def brotli_compressor(self): + if self._brotli_compressor is None: + self._brotli_compressor = brotli.Compressor() + return self._brotli_compressor + + def finish_brotli_compressor(self): + """Flush remaining brotli bytes.""" + # TODO: Move compression out of `BatchExportTemporaryFile` to a standard class for all writers. 
+ if self.compression != "brotli": + raise ValueError(f"Compression is '{self.compression}', not 'brotli'") + + result = self._file.write(self.brotli_compressor.finish()) + self.bytes_total += result + self.bytes_since_last_reset += result + self._brotli_compressor = None + + def compress(self, content: bytes | str) -> bytes: + if isinstance(content, str): + encoded = content.encode("utf-8") + else: + encoded = content + + match self.compression: + case "gzip": + return gzip.compress(encoded) + case "brotli": + self.brotli_compressor.process(encoded) + return self.brotli_compressor.flush() + case None: + return encoded + case _: + raise ValueError(f"Unsupported compression: '{self.compression}'") + + def write(self, content: bytes | str): + """Write bytes to underlying file keeping track of how many bytes were written.""" + compressed_content = self.compress(content) + + if "b" in self.mode: + result = self._file.write(compressed_content) + else: + result = self._file.write(compressed_content.decode("utf-8")) + + self.bytes_total += result + self.bytes_since_last_reset += result + + return result + + def write_record_as_bytes(self, record: bytes): + result = self.write(record) + + self.records_total += 1 + self.records_since_last_reset += 1 + + return result + + def write_records_to_jsonl(self, records): + """Write records to a temporary file as JSONL.""" + if len(records) == 1: + jsonl_dump = orjson.dumps(records[0], option=orjson.OPT_APPEND_NEWLINE, default=str) + else: + jsonl_dump = b"\n".join(map(json_dumps_bytes, records)) + + result = self.write(jsonl_dump) + + self.records_total += len(records) + self.records_since_last_reset += len(records) + + return result + + def write_records_to_csv( + self, + records, + fieldnames: None | collections.abc.Sequence[str] = None, + extrasaction: typing.Literal["raise", "ignore"] = "ignore", + delimiter: str = ",", + quotechar: str = '"', + escapechar: str | None = "\\", + lineterminator: str = "\n", + 
quoting=csv.QUOTE_NONE, + ): + """Write records to a temporary file as CSV.""" + if len(records) == 0: + return + + if fieldnames is None: + fieldnames = list(records[0].keys()) + + writer = csv.DictWriter( + self, + fieldnames=fieldnames, + extrasaction=extrasaction, + delimiter=delimiter, + quotechar=quotechar, + escapechar=escapechar, + quoting=quoting, + lineterminator=lineterminator, + ) + writer.writerows(records) + + self.records_total += len(records) + self.records_since_last_reset += len(records) + + def write_records_to_tsv( + self, + records, + fieldnames: None | list[str] = None, + extrasaction: typing.Literal["raise", "ignore"] = "ignore", + quotechar: str = '"', + escapechar: str | None = "\\", + lineterminator: str = "\n", + quoting=csv.QUOTE_NONE, + ): + """Write records to a temporary file as TSV.""" + return self.write_records_to_csv( + records, + fieldnames=fieldnames, + extrasaction=extrasaction, + delimiter="\t", + quotechar=quotechar, + escapechar=escapechar, + quoting=quoting, + lineterminator=lineterminator, + ) + + def rewind(self): + """Rewind the file before reading it.""" + self._file.seek(0) + + def reset(self): + """Reset underlying file by truncating it. + + Also resets the tracker attributes for bytes and records since last reset. 
+ """ + self._file.seek(0) + self._file.truncate() + + self.bytes_since_last_reset = 0 + self.records_since_last_reset = 0 + + +LastInsertedAt = dt.datetime +IsLast = bool +RecordsSinceLastFlush = int +BytesSinceLastFlush = int +FlushCallable = collections.abc.Callable[ + [BatchExportTemporaryFile, RecordsSinceLastFlush, BytesSinceLastFlush, LastInsertedAt, IsLast], + collections.abc.Awaitable[None], +] + + +class UnsupportedFileFormatError(Exception): + """Raised when a writer for an unsupported file format is requested.""" + + def __init__(self, file_format: str, destination: str): + super().__init__(f"{file_format} is not a supported format for {destination} batch exports.") + + +class BatchExportWriter(abc.ABC): + """A temporary file writer to be used by batch export workflows. + + Subclasses should define `_write_record_batch` with the particular intricacies + of the format they are writing as. + + Actual writing calls are passed to the underlying `batch_export_file`. + + Attributes: + _batch_export_file: The temporary file we are writing to. + max_bytes: Flush the temporary file with the provided `flush_callable` + upon reaching or surpassing this threshold. Keep in mind we write on a RecordBatch + per RecordBatch basis, which means the threshold will be surpassed by at most the + size of a RecordBatch before a flush occurs. + flush_callable: A callback to flush the temporary file when `max_bytes` is reached. + The temporary file will be reset after calling `flush_callable`. When calling + `flush_callable` the following positional arguments will be passed: The temporary file + that must be flushed, the number of records since the last flush, the number of bytes + since the last flush, the latest recorded `_inserted_at`, and a `bool` indicating if + this is the last flush (when exiting the context manager). + file_kwargs: Optional keyword arguments passed when initializing `_batch_export_file`. + last_inserted_at: Latest `_inserted_at` written. 
This attribute leaks some implementation + details, as we are assuming `_inserted_at` is present, as it's added to all + batch export queries. + records_total: The total number of records (not RecordBatches!) written. + records_since_last_flush: The number of records written since last flush. + bytes_total: The total number of bytes written. + bytes_since_last_flush: The number of bytes written since last flush. + """ + + def __init__( + self, + flush_callable: FlushCallable, + max_bytes: int, + file_kwargs: collections.abc.Mapping[str, typing.Any] | None = None, + ): + self.flush_callable = flush_callable + self.max_bytes = max_bytes + self.file_kwargs: collections.abc.Mapping[str, typing.Any] = file_kwargs or {} + + self._batch_export_file: BatchExportTemporaryFile | None = None + self.reset_writer_tracking() + + def reset_writer_tracking(self): + """Reset this writer's tracking state.""" + self.last_inserted_at: dt.datetime | None = None + self.records_total = 0 + self.records_since_last_flush = 0 + self.bytes_total = 0 + self.bytes_since_last_flush = 0 + + @contextlib.asynccontextmanager + async def open_temporary_file(self): + """Explicitly open the temporary file this writer is writing to. + + The underlying `BatchExportTemporaryFile` is only accessible within this context manager. This helps + us separate the lifetime of the underlying temporary file from the writer: The writer may still be + accessed even after the temporary file is closed, while on the other hand we ensure the file and all + its data is flushed and not leaked outside the context. Any relevant tracking information is copied + to the writer. + """ + self.reset_writer_tracking() + + with BatchExportTemporaryFile(**self.file_kwargs) as temp_file: + self._batch_export_file = temp_file + + try: + yield + finally: + self.track_bytes_written(temp_file) + + if self.last_inserted_at is not None and self.bytes_since_last_flush > 0: + # `bytes_since_last_flush` should be 0 unless: + # 1. 
The last batch wasn't flushed as it didn't reach `max_bytes`. + # 2. The last batch was flushed but there was another write after the last call to + # `write_record_batch`. For example, footer bytes. + await self.flush(self.last_inserted_at, is_last=True) + + self._batch_export_file = None + + @property + def batch_export_file(self): + """Property for underlying temporary file. + + Raises: + ValueError: if attempting to access the temporary file before it has been opened. + """ + if self._batch_export_file is None: + raise ValueError("Batch export file is closed. Did you forget to call 'open_temporary_file'?") + return self._batch_export_file + + @abc.abstractmethod + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write a record batch to the underlying `BatchExportTemporaryFile`. + + Subclasses must override this to provide the actual implementation according to the supported + file format. + """ + pass + + def track_records_written(self, record_batch: pa.RecordBatch) -> None: + """Update this writer's state with the number of records in `record_batch`.""" + self.records_total += record_batch.num_rows + self.records_since_last_flush += record_batch.num_rows + + def track_bytes_written(self, batch_export_file: BatchExportTemporaryFile) -> None: + """Update this writer's state with the bytes in `batch_export_file`.""" + self.bytes_total = batch_export_file.bytes_total + self.bytes_since_last_flush = batch_export_file.bytes_since_last_reset + + async def write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Issue a record batch write tracking progress and flushing if required.""" + record_batch = record_batch.sort_by("_inserted_at") + last_inserted_at = record_batch.column("_inserted_at")[-1].as_py() + + column_names = record_batch.column_names + column_names.pop(column_names.index("_inserted_at")) + + self._write_record_batch(record_batch.select(column_names)) + + self.last_inserted_at = last_inserted_at + 
self.track_records_written(record_batch) + self.track_bytes_written(self.batch_export_file) + + if self.bytes_since_last_flush >= self.max_bytes: + await self.flush(last_inserted_at) + + async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None: + """Call the provided `flush_callable` and reset underlying file. + + The underlying batch export temporary file will be reset after calling `flush_callable`. + """ + if is_last is True and self.batch_export_file.compression == "brotli": + self.batch_export_file.finish_brotli_compressor() + + self.batch_export_file.seek(0) + + await self.flush_callable( + self.batch_export_file, + self.records_since_last_flush, + self.bytes_since_last_flush, + last_inserted_at, + is_last, + ) + self.batch_export_file.reset() + + self.records_since_last_flush = 0 + self.bytes_since_last_flush = 0 + + +class JSONLBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for JSONLines format. + + Attributes: + default: The default function to use to cast non-serializable Python objects to serializable objects. + By default, non-serializable objects will be cast to string via `str()`. 
+ """ + + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + compression: None | str = None, + default: typing.Callable = str, + ): + super().__init__( + max_bytes=max_bytes, + flush_callable=flush_callable, + file_kwargs={"compression": compression}, + ) + + self.default = default + + def write(self, content: bytes) -> int: + """Write a single row of JSONL.""" + n = self.batch_export_file.write(orjson.dumps(content, default=str) + b"\n") + return n + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as JSONL.""" + for record in record_batch.to_pylist(): + self.write(record) + + +class CSVBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for CSV format.""" + + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + field_names: collections.abc.Sequence[str], + extras_action: typing.Literal["raise", "ignore"] = "ignore", + delimiter: str = ",", + quote_char: str = '"', + escape_char: str | None = "\\", + line_terminator: str = "\n", + quoting=csv.QUOTE_NONE, + compression: str | None = None, + ): + super().__init__( + max_bytes=max_bytes, + flush_callable=flush_callable, + file_kwargs={"compression": compression}, + ) + self.field_names = field_names + self.extras_action: typing.Literal["raise", "ignore"] = extras_action + self.delimiter = delimiter + self.quote_char = quote_char + self.escape_char = escape_char + self.line_terminator = line_terminator + self.quoting = quoting + + self._csv_writer: csv.DictWriter | None = None + + @property + def csv_writer(self) -> csv.DictWriter: + if self._csv_writer is None: + self._csv_writer = csv.DictWriter( + self.batch_export_file, + fieldnames=self.field_names, + extrasaction=self.extras_action, + delimiter=self.delimiter, + quotechar=self.quote_char, + escapechar=self.escape_char, + quoting=self.quoting, + lineterminator=self.line_terminator, + ) + + return self._csv_writer + + def 
_write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as CSV.""" + self.csv_writer.writerows(record_batch.to_pylist()) + + +class ParquetBatchExportWriter(BatchExportWriter): + """A `BatchExportWriter` for Apache Parquet format. + + We utilize and wrap a `pyarrow.parquet.ParquetWriter` to do the actual writing. We default to their + defaults for most parameters; however this class could be extended with more attributes to pass along + to `pyarrow.parquet.ParquetWriter`. + + See the pyarrow docs for more details on which parameters the writer can be configured with: + https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html + + In contrast to other writers, instead of us handling compression we let `pyarrow.parquet.ParquetWriter` + handle it, so `BatchExportTemporaryFile` is always initialized with `compression=None`. + + Attributes: + schema: The schema used by the Parquet file. Should match the schema of written RecordBatches. + compression: Compression codec passed to underlying `pyarrow.parquet.ParquetWriter`. 
+ """ + + def __init__( + self, + max_bytes: int, + flush_callable: FlushCallable, + schema: pa.Schema, + compression: str | None = "snappy", + ): + super().__init__( + max_bytes=max_bytes, + flush_callable=flush_callable, + file_kwargs={"compression": None}, # ParquetWriter handles compression + ) + self.schema = schema + self.compression = compression + + self._parquet_writer: pq.ParquetWriter | None = None + + @property + def parquet_writer(self) -> pq.ParquetWriter: + if self._parquet_writer is None: + self._parquet_writer = pq.ParquetWriter( + self.batch_export_file, + schema=self.schema, + compression="none" if self.compression is None else self.compression, + ) + return self._parquet_writer + + @contextlib.asynccontextmanager + async def open_temporary_file(self): + """Ensure underlying Parquet writer is closed before flushing and closing temporary file.""" + async with super().open_temporary_file(): + try: + yield + finally: + if self._parquet_writer is not None: + self._parquet_writer.writer.close() + self._parquet_writer = None + + def _write_record_batch(self, record_batch: pa.RecordBatch) -> None: + """Write records to a temporary file as Parquet.""" + + self.parquet_writer.write_batch(record_batch.select(self.parquet_writer.schema.names)) diff --git a/posthog/temporal/tests/batch_exports/test_batch_exports.py b/posthog/temporal/tests/batch_exports/test_batch_exports.py index 0afbfcabb71cb..756c07e442e4f 100644 --- a/posthog/temporal/tests/batch_exports/test_batch_exports.py +++ b/posthog/temporal/tests/batch_exports/test_batch_exports.py @@ -1,6 +1,4 @@ -import csv import datetime as dt -import io import json import operator from random import randint @@ -9,11 +7,9 @@ from django.test import override_settings from posthog.temporal.batch_exports.batch_exports import ( - BatchExportTemporaryFile, get_data_interval, get_rows_count, iter_records, - json_dumps_bytes, ) from posthog.temporal.tests.utils.events import generate_test_events_in_clickhouse @@ 
-558,181 +554,3 @@ def test_get_data_interval(interval, data_interval_end, expected): """Test get_data_interval returns the expected data interval tuple.""" result = get_data_interval(interval, data_interval_end) assert result == expected - - -@pytest.mark.parametrize( - "to_write", - [ - (b"",), - (b"", b""), - (b"12345",), - (b"12345", b"12345"), - (b"abbcccddddeeeee",), - (b"abbcccddddeeeee", b"abbcccddddeeeee"), - ], -) -def test_batch_export_temporary_file_tracks_bytes(to_write): - """Test the bytes written by BatchExportTemporaryFile match expected.""" - with BatchExportTemporaryFile() as be_file: - for content in to_write: - be_file.write(content) - - assert be_file.bytes_total == sum(len(content) for content in to_write) - assert be_file.bytes_since_last_reset == sum(len(content) for content in to_write) - - be_file.reset() - - assert be_file.bytes_total == sum(len(content) for content in to_write) - assert be_file.bytes_since_last_reset == 0 - - -TEST_RECORDS = [ - [], - [ - {"id": "record-1", "property": "value", "property_int": 1}, - {"id": "record-2", "property": "another-value", "property_int": 2}, - { - "id": "record-3", - "property": {"id": "nested-record", "property": "nested-value"}, - "property_int": 3, - }, - ], -] - - -@pytest.mark.parametrize( - "records", - TEST_RECORDS, -) -def test_batch_export_temporary_file_write_records_to_jsonl(records): - """Test JSONL records written by BatchExportTemporaryFile match expected.""" - jsonl_dump = b"\n".join(map(json_dumps_bytes, records)) - - with BatchExportTemporaryFile() as be_file: - be_file.write_records_to_jsonl(records) - - assert be_file.bytes_total == len(jsonl_dump) - assert be_file.bytes_since_last_reset == len(jsonl_dump) - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == len(records) - - be_file.seek(0) - lines = be_file.readlines() - assert len(lines) == len(records) - - for line_index, jsonl_record in enumerate(lines): - json_loaded = 
json.loads(jsonl_record) - assert json_loaded == records[line_index] - - be_file.reset() - - assert be_file.bytes_total == len(jsonl_dump) - assert be_file.bytes_since_last_reset == 0 - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == 0 - - -@pytest.mark.parametrize( - "records", - TEST_RECORDS, -) -def test_batch_export_temporary_file_write_records_to_csv(records): - """Test CSV written by BatchExportTemporaryFile match expected.""" - in_memory_file_obj = io.StringIO() - writer = csv.DictWriter( - in_memory_file_obj, - fieldnames=records[0].keys() if len(records) > 0 else [], - delimiter=",", - quotechar='"', - escapechar="\\", - lineterminator="\n", - quoting=csv.QUOTE_NONE, - ) - writer.writerows(records) - - with BatchExportTemporaryFile(mode="w+") as be_file: - be_file.write_records_to_csv(records) - - assert be_file.bytes_total == in_memory_file_obj.tell() - assert be_file.bytes_since_last_reset == in_memory_file_obj.tell() - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == len(records) - - be_file.seek(0) - reader = csv.reader( - be_file._file, - delimiter=",", - quotechar='"', - escapechar="\\", - quoting=csv.QUOTE_NONE, - ) - - rows = [row for row in reader] - assert len(rows) == len(records) - - for row_index, csv_record in enumerate(rows): - for value_index, value in enumerate(records[row_index].values()): - # Everything returned by csv.reader is a str. - # This means type information is lost when writing to CSV - # but this just a limitation of the format. 
- assert csv_record[value_index] == str(value) - - be_file.reset() - - assert be_file.bytes_total == in_memory_file_obj.tell() - assert be_file.bytes_since_last_reset == 0 - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == 0 - - -@pytest.mark.parametrize( - "records", - TEST_RECORDS, -) -def test_batch_export_temporary_file_write_records_to_tsv(records): - """Test TSV written by BatchExportTemporaryFile match expected.""" - in_memory_file_obj = io.StringIO() - writer = csv.DictWriter( - in_memory_file_obj, - fieldnames=records[0].keys() if len(records) > 0 else [], - delimiter="\t", - quotechar='"', - escapechar="\\", - lineterminator="\n", - quoting=csv.QUOTE_NONE, - ) - writer.writerows(records) - - with BatchExportTemporaryFile(mode="w+") as be_file: - be_file.write_records_to_tsv(records) - - assert be_file.bytes_total == in_memory_file_obj.tell() - assert be_file.bytes_since_last_reset == in_memory_file_obj.tell() - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == len(records) - - be_file.seek(0) - reader = csv.reader( - be_file._file, - delimiter="\t", - quotechar='"', - escapechar="\\", - quoting=csv.QUOTE_NONE, - ) - - rows = [row for row in reader] - assert len(rows) == len(records) - - for row_index, csv_record in enumerate(rows): - for value_index, value in enumerate(records[row_index].values()): - # Everything returned by csv.reader is a str. - # This means type information is lost when writing to CSV - # but this just a limitation of the format. 
- assert csv_record[value_index] == str(value) - - be_file.reset() - - assert be_file.bytes_total == in_memory_file_obj.tell() - assert be_file.bytes_since_last_reset == 0 - assert be_file.records_total == len(records) - assert be_file.records_since_last_reset == 0 diff --git a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py index e04e345d11245..e6583d049e2a8 100644 --- a/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py +++ b/posthog/temporal/tests/batch_exports/test_s3_batch_export_workflow.py @@ -10,10 +10,12 @@ import aioboto3 import botocore.exceptions import brotli +import pyarrow.parquet as pq import pytest import pytest_asyncio from django.conf import settings from django.test import override_settings +from pyarrow import fs from temporalio import activity from temporalio.client import WorkflowFailureError from temporalio.common import RetryPolicy @@ -27,6 +29,7 @@ update_export_run_status, ) from posthog.temporal.batch_exports.s3_batch_export import ( + FILE_FORMAT_EXTENSIONS, HeartbeatDetails, S3BatchExportInputs, S3BatchExportWorkflow, @@ -107,6 +110,15 @@ def s3_key_prefix(): return f"posthog-events-{str(uuid4())}" +@pytest.fixture +def file_format(request) -> str: + """S3 file format.""" + try: + return request.param + except AttributeError: + return f"JSONLines" + + async def delete_all_from_s3(minio_client, bucket_name: str, key_prefix: str): """Delete all objects in bucket_name under key_prefix.""" response = await minio_client.list_objects_v2(Bucket=bucket_name, Prefix=key_prefix) @@ -138,6 +150,61 @@ async def minio_client(bucket_name): await minio_client.delete_bucket(Bucket=bucket_name) +async def read_parquet_from_s3(bucket_name: str, key: str, json_columns) -> list: + async with aioboto3.Session().client("sts") as sts: + try: + await sts.get_caller_identity() + except botocore.exceptions.NoCredentialsError: + s3 = fs.S3FileSystem( 
+ access_key="object_storage_root_user", + secret_key="object_storage_root_password", + endpoint_override=settings.OBJECT_STORAGE_ENDPOINT, + ) + + else: + if os.getenv("S3_TEST_BUCKET") is not None: + s3 = fs.S3FileSystem() + else: + s3 = fs.S3FileSystem( + access_key="object_storage_root_user", + secret_key="object_storage_root_password", + endpoint_override=settings.OBJECT_STORAGE_ENDPOINT, + ) + + table = pq.read_table(f"{bucket_name}/{key}", filesystem=s3) + + parquet_data = [] + for batch in table.to_batches(): + for record in batch.to_pylist(): + casted_record = {} + for k, v in record.items(): + if isinstance(v, dt.datetime): + # We read data from clickhouse as string, but parquet already casts them as dates. + # To facilitate comparison, we isoformat the dates. + casted_record[k] = v.isoformat() + elif k in json_columns and v is not None: + # Parquet doesn't have a variable map type, so JSON fields are just strings. + casted_record[k] = json.loads(v) + else: + casted_record[k] = v + parquet_data.append(casted_record) + + return parquet_data + + +def read_s3_data_as_json(data: bytes, compression: str | None) -> list: + match compression: + case "gzip": + data = gzip.decompress(data) + case "brotli": + data = brotli.decompress(data) + case _: + pass + + json_data = [json.loads(line) for line in data.decode("utf-8").split("\n") if line] + return json_data + + async def assert_clickhouse_records_in_s3( s3_compatible_client, clickhouse_client: ClickHouseClient, @@ -150,6 +217,7 @@ async def assert_clickhouse_records_in_s3( include_events: list[str] | None = None, batch_export_schema: BatchExportSchema | None = None, compression: str | None = None, + file_format: str = "JSONLines", ): """Assert ClickHouse records are written to JSON in key_prefix in S3 bucket_name. @@ -175,28 +243,24 @@ async def assert_clickhouse_records_in_s3( # Get the object. 
key = objects["Contents"][0].get("Key") assert key - s3_object = await s3_compatible_client.get_object(Bucket=bucket_name, Key=key) - data = await s3_object["Body"].read() - # Check that the data is correct. - match compression: - case "gzip": - data = gzip.decompress(data) - case "brotli": - data = brotli.decompress(data) - case _: - pass + json_columns = ("properties", "person_properties", "set", "set_once") - json_data = [json.loads(line) for line in data.decode("utf-8").split("\n") if line] - # Pull out the fields we inserted only + if file_format == "Parquet": + s3_data = await read_parquet_from_s3(bucket_name, key, json_columns) + + elif file_format == "JSONLines": + s3_object = await s3_compatible_client.get_object(Bucket=bucket_name, Key=key) + data = await s3_object["Body"].read() + s3_data = read_s3_data_as_json(data, compression) + else: + raise ValueError(f"Unsupported file format: {file_format}") if batch_export_schema is not None: schema_column_names = [field["alias"] for field in batch_export_schema["fields"]] else: schema_column_names = [field["alias"] for field in s3_default_fields()] - json_columns = ("properties", "person_properties", "set", "set_once") - expected_records = [] for record_batch in iter_records( client=clickhouse_client, @@ -225,9 +289,9 @@ async def assert_clickhouse_records_in_s3( expected_records.append(expected_record) - assert len(json_data) == len(expected_records) - assert json_data[0] == expected_records[0] - assert json_data == expected_records + assert len(s3_data) == len(expected_records) + assert s3_data[0] == expected_records[0] + assert s3_data == expected_records TEST_S3_SCHEMAS: list[BatchExportSchema | None] = [ @@ -255,6 +319,7 @@ async def assert_clickhouse_records_in_s3( @pytest.mark.parametrize("compression", [None, "gzip", "brotli"], indirect=True) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) 
+@pytest.mark.parametrize("file_format", FILE_FORMAT_EXTENSIONS.keys()) async def test_insert_into_s3_activity_puts_data_into_s3( clickhouse_client, bucket_name, @@ -262,6 +327,7 @@ async def test_insert_into_s3_activity_puts_data_into_s3( activity_environment, compression, exclude_events, + file_format, batch_export_schema: BatchExportSchema | None, ): """Test that the insert_into_s3_activity function ends up with data into S3. @@ -339,12 +405,15 @@ async def test_insert_into_s3_activity_puts_data_into_s3( compression=compression, exclude_events=exclude_events, batch_export_schema=batch_export_schema, + file_format=file_format, ) with override_settings( BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2 ): # 5MB, the minimum for Multipart uploads - await activity_environment.run(insert_into_s3_activity, insert_inputs) + records_total = await activity_environment.run(insert_into_s3_activity, insert_inputs) + + assert records_total == 10005 await assert_clickhouse_records_in_s3( s3_compatible_client=minio_client, @@ -358,6 +427,7 @@ async def test_insert_into_s3_activity_puts_data_into_s3( exclude_events=exclude_events, include_events=None, compression=compression, + file_format=file_format, ) @@ -371,6 +441,7 @@ async def s3_batch_export( exclude_events, temporal_client, encryption, + file_format, ): destination_data = { "type": "S3", @@ -385,6 +456,7 @@ async def s3_batch_export( "exclude_events": exclude_events, "encryption": encryption, "kms_key_id": os.getenv("S3_TEST_KMS_KEY_ID") if encryption == "aws:kms" else None, + "file_format": file_format, }, } @@ -410,6 +482,7 @@ async def s3_batch_export( @pytest.mark.parametrize("compression", [None, "gzip", "brotli"], indirect=True) @pytest.mark.parametrize("exclude_events", [None, ["test-exclude"]], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) +@pytest.mark.parametrize("file_format", FILE_FORMAT_EXTENSIONS.keys(), indirect=True) async def 
test_s3_export_workflow_with_minio_bucket( clickhouse_client, minio_client, @@ -421,6 +494,7 @@ async def test_s3_export_workflow_with_minio_bucket( exclude_events, s3_key_prefix, batch_export_schema, + file_format, ): """Test S3BatchExport Workflow end-to-end by using a local MinIO bucket instead of S3. @@ -508,6 +582,7 @@ async def test_s3_export_workflow_with_minio_bucket( batch_export_schema=batch_export_schema, exclude_events=exclude_events, compression=compression, + file_format=file_format, ) @@ -537,6 +612,7 @@ async def s3_client(bucket_name, s3_key_prefix): @pytest.mark.parametrize("encryption", [None, "AES256", "aws:kms"], indirect=True) @pytest.mark.parametrize("bucket_name", [os.getenv("S3_TEST_BUCKET")], indirect=True) @pytest.mark.parametrize("batch_export_schema", TEST_S3_SCHEMAS) +@pytest.mark.parametrize("file_format", FILE_FORMAT_EXTENSIONS.keys(), indirect=True) async def test_s3_export_workflow_with_s3_bucket( s3_client, clickhouse_client, @@ -549,6 +625,7 @@ async def test_s3_export_workflow_with_s3_bucket( exclude_events, ateam, batch_export_schema, + file_format, ): """Test S3 Export Workflow end-to-end by using an S3 bucket. 
@@ -646,6 +723,7 @@ async def test_s3_export_workflow_with_s3_bucket( exclude_events=exclude_events, include_events=None, compression=compression, + file_format=file_format, ) @@ -1206,6 +1284,49 @@ async def never_finish_activity(_: S3InsertInputs) -> str: ), "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.jsonl.br", ), + ( + S3InsertInputs( + prefix="/nested/prefix/", + data_interval_start="2023-01-01 00:00:00", + data_interval_end="2023-01-01 01:00:00", + file_format="Parquet", + compression="snappy", + **base_inputs, # type: ignore + ), + "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.parquet.sz", + ), + ( + S3InsertInputs( + prefix="/nested/prefix/", + data_interval_start="2023-01-01 00:00:00", + data_interval_end="2023-01-01 01:00:00", + file_format="Parquet", + **base_inputs, # type: ignore + ), + "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.parquet", + ), + ( + S3InsertInputs( + prefix="/nested/prefix/", + data_interval_start="2023-01-01 00:00:00", + data_interval_end="2023-01-01 01:00:00", + compression="gzip", + file_format="Parquet", + **base_inputs, # type: ignore + ), + "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.parquet.gz", + ), + ( + S3InsertInputs( + prefix="/nested/prefix/", + data_interval_start="2023-01-01 00:00:00", + data_interval_end="2023-01-01 01:00:00", + compression="brotli", + file_format="Parquet", + **base_inputs, # type: ignore + ), + "nested/prefix/2023-01-01 00:00:00-2023-01-01 01:00:00.parquet.br", + ), ], ) def test_get_s3_key(inputs, expected): @@ -1271,7 +1392,7 @@ def assert_heartbeat_details(*details): endpoint_url=settings.OBJECT_STORAGE_ENDPOINT, ) - with override_settings(BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=5 * 1024**2): + with override_settings(BATCH_EXPORT_S3_UPLOAD_CHUNK_SIZE_BYTES=1, CLICKHOUSE_MAX_BLOCK_SIZE_DEFAULT=1): await activity_environment.run(insert_into_s3_activity, insert_inputs) # This checks that the assert_heartbeat_details function was actually called. 
diff --git a/posthog/temporal/tests/batch_exports/test_temporary_file.py b/posthog/temporal/tests/batch_exports/test_temporary_file.py new file mode 100644 index 0000000000000..4fd7e69c0c12f --- /dev/null +++ b/posthog/temporal/tests/batch_exports/test_temporary_file.py @@ -0,0 +1,389 @@ +import csv +import datetime as dt +import io +import json + +import pyarrow as pa +import pyarrow.parquet as pq +import pytest + +from posthog.temporal.batch_exports.temporary_file import ( + BatchExportTemporaryFile, + CSVBatchExportWriter, + JSONLBatchExportWriter, + ParquetBatchExportWriter, + json_dumps_bytes, +) + + +@pytest.mark.parametrize( + "to_write", + [ + (b"",), + (b"", b""), + (b"12345",), + (b"12345", b"12345"), + (b"abbcccddddeeeee",), + (b"abbcccddddeeeee", b"abbcccddddeeeee"), + ], +) +def test_batch_export_temporary_file_tracks_bytes(to_write): + """Test the bytes written by BatchExportTemporaryFile match expected.""" + with BatchExportTemporaryFile() as be_file: + for content in to_write: + be_file.write(content) + + assert be_file.bytes_total == sum(len(content) for content in to_write) + assert be_file.bytes_since_last_reset == sum(len(content) for content in to_write) + + be_file.reset() + + assert be_file.bytes_total == sum(len(content) for content in to_write) + assert be_file.bytes_since_last_reset == 0 + + +TEST_RECORDS = [ + [], + [ + {"id": "record-1", "property": "value", "property_int": 1}, + {"id": "record-2", "property": "another-value", "property_int": 2}, + { + "id": "record-3", + "property": {"id": "nested-record", "property": "nested-value"}, + "property_int": 3, + }, + ], +] + + +@pytest.mark.parametrize( + "records", + TEST_RECORDS, +) +def test_batch_export_temporary_file_write_records_to_jsonl(records): + """Test JSONL records written by BatchExportTemporaryFile match expected.""" + jsonl_dump = b"\n".join(map(json_dumps_bytes, records)) + + with BatchExportTemporaryFile() as be_file: + be_file.write_records_to_jsonl(records) + + assert 
be_file.bytes_total == len(jsonl_dump) + assert be_file.bytes_since_last_reset == len(jsonl_dump) + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == len(records) + + be_file.seek(0) + lines = be_file.readlines() + assert len(lines) == len(records) + + for line_index, jsonl_record in enumerate(lines): + json_loaded = json.loads(jsonl_record) + assert json_loaded == records[line_index] + + be_file.reset() + + assert be_file.bytes_total == len(jsonl_dump) + assert be_file.bytes_since_last_reset == 0 + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == 0 + + +@pytest.mark.parametrize( + "records", + TEST_RECORDS, +) +def test_batch_export_temporary_file_write_records_to_csv(records): + """Test CSV written by BatchExportTemporaryFile match expected.""" + in_memory_file_obj = io.StringIO() + writer = csv.DictWriter( + in_memory_file_obj, + fieldnames=records[0].keys() if len(records) > 0 else [], + delimiter=",", + quotechar='"', + escapechar="\\", + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + writer.writerows(records) + + with BatchExportTemporaryFile(mode="w+") as be_file: + be_file.write_records_to_csv(records) + + assert be_file.bytes_total == in_memory_file_obj.tell() + assert be_file.bytes_since_last_reset == in_memory_file_obj.tell() + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == len(records) + + be_file.seek(0) + reader = csv.reader( + be_file._file, + delimiter=",", + quotechar='"', + escapechar="\\", + quoting=csv.QUOTE_NONE, + ) + + rows = [row for row in reader] + assert len(rows) == len(records) + + for row_index, csv_record in enumerate(rows): + for value_index, value in enumerate(records[row_index].values()): + # Everything returned by csv.reader is a str. + # This means type information is lost when writing to CSV + # but this is just a limitation of the format. 
+ assert csv_record[value_index] == str(value) + + be_file.reset() + + assert be_file.bytes_total == in_memory_file_obj.tell() + assert be_file.bytes_since_last_reset == 0 + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == 0 + + +@pytest.mark.parametrize( + "records", + TEST_RECORDS, +) +def test_batch_export_temporary_file_write_records_to_tsv(records): + """Test TSV written by BatchExportTemporaryFile match expected.""" + in_memory_file_obj = io.StringIO() + writer = csv.DictWriter( + in_memory_file_obj, + fieldnames=records[0].keys() if len(records) > 0 else [], + delimiter="\t", + quotechar='"', + escapechar="\\", + lineterminator="\n", + quoting=csv.QUOTE_NONE, + ) + writer.writerows(records) + + with BatchExportTemporaryFile(mode="w+") as be_file: + be_file.write_records_to_tsv(records) + + assert be_file.bytes_total == in_memory_file_obj.tell() + assert be_file.bytes_since_last_reset == in_memory_file_obj.tell() + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == len(records) + + be_file.seek(0) + reader = csv.reader( + be_file._file, + delimiter="\t", + quotechar='"', + escapechar="\\", + quoting=csv.QUOTE_NONE, + ) + + rows = [row for row in reader] + assert len(rows) == len(records) + + for row_index, csv_record in enumerate(rows): + for value_index, value in enumerate(records[row_index].values()): + # Everything returned by csv.reader is a str. + # This means type information is lost when writing to CSV + # but this is just a limitation of the format. 
+ assert csv_record[value_index] == str(value) + + be_file.reset() + + assert be_file.bytes_total == in_memory_file_obj.tell() + assert be_file.bytes_since_last_reset == 0 + assert be_file.records_total == len(records) + assert be_file.records_since_last_reset == 0 + + +TEST_RECORD_BATCHES = [ + pa.RecordBatch.from_pydict( + { + "event": pa.array(["test-event-0", "test-event-1", "test-event-2"]), + "properties": pa.array(['{"prop_0": 1, "prop_1": 2}', "{}", "null"]), + "_inserted_at": pa.array([0, 1, 2]), + } + ) +] + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_jsonl_writer_writes_record_batches(record_batch): + """Test record batches are written as valid JSONL.""" + in_memory_file_obj = io.BytesIO() + inserted_ats_seen = [] + + async def store_in_memory_on_flush( + batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last + ): + in_memory_file_obj.write(batch_export_file.read()) + inserted_ats_seen.append(last_inserted_at) + + writer = JSONLBatchExportWriter(max_bytes=1, flush_callable=store_in_memory_on_flush) + + record_batch = record_batch.sort_by("_inserted_at") + async with writer.open_temporary_file(): + await writer.write_record_batch(record_batch) + + lines = in_memory_file_obj.readlines() + for index, line in enumerate(lines): + written_jsonl = json.loads(line) + + single_record_batch = record_batch.slice(offset=index, length=1) + expected_jsonl = single_record_batch.to_pylist()[0] + + assert "_inserted_at" not in written_jsonl + assert written_jsonl == expected_jsonl + + assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()] + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_csv_writer_writes_record_batches(record_batch): + """Test record batches are written as valid CSV.""" + in_memory_file_obj = io.StringIO() + inserted_ats_seen = [] + + async def 
store_in_memory_on_flush( + batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last + ): + in_memory_file_obj.write(batch_export_file.read().decode("utf-8")) + inserted_ats_seen.append(last_inserted_at) + + schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] + writer = CSVBatchExportWriter(max_bytes=1, field_names=schema_columns, flush_callable=store_in_memory_on_flush) + + record_batch = record_batch.sort_by("_inserted_at") + async with writer.open_temporary_file(): + await writer.write_record_batch(record_batch) + + reader = csv.reader( + in_memory_file_obj, + delimiter=",", + quotechar='"', + escapechar="\\", + quoting=csv.QUOTE_NONE, + ) + for index, written_csv_row in enumerate(reader): + single_record_batch = record_batch.slice(offset=index, length=1) + expected_csv = single_record_batch.to_pylist()[0] + + assert "_inserted_at" not in written_csv_row + assert written_csv_row == expected_csv + + assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()] + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_parquet_writer_writes_record_batches(record_batch): + """Test record batches are written as valid Parquet.""" + in_memory_file_obj = io.BytesIO() + inserted_ats_seen = [] + + async def store_in_memory_on_flush( + batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last + ): + in_memory_file_obj.write(batch_export_file.read()) + inserted_ats_seen.append(last_inserted_at) + + schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] + + writer = ParquetBatchExportWriter( + max_bytes=1, + flush_callable=store_in_memory_on_flush, + schema=record_batch.select(schema_columns).schema, + ) + + record_batch = record_batch.sort_by("_inserted_at") + async with writer.open_temporary_file(): + await 
writer.write_record_batch(record_batch) + + written_parquet = pq.read_table(in_memory_file_obj) + + for index, written_row_as_dict in enumerate(written_parquet.to_pylist()): + single_record_batch = record_batch.slice(offset=index, length=1) + expected_row_as_dict = single_record_batch.select(schema_columns).to_pylist()[0] + + assert "_inserted_at" not in written_row_as_dict + assert written_row_as_dict == expected_row_as_dict + + # NOTE: Parquet gets flushed twice due to the extra flush at the end for footer bytes, so our mock function + # will see this value twice. + assert inserted_ats_seen == [ + record_batch.column("_inserted_at")[-1].as_py(), + record_batch.column("_inserted_at")[-1].as_py(), + ] + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_writing_out_of_scope_of_temporary_file_raises(record_batch): + """Test attempting a write out of temporary file scope raises a `ValueError`.""" + + async def do_nothing(*args, **kwargs): + pass + + schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] + writer = ParquetBatchExportWriter( + max_bytes=10, + flush_callable=do_nothing, + schema=record_batch.select(schema_columns).schema, + ) + + async with writer.open_temporary_file(): + pass + + with pytest.raises(ValueError, match="Batch export file is closed"): + await writer.write_record_batch(record_batch) + + +@pytest.mark.parametrize( + "record_batch", + TEST_RECORD_BATCHES, +) +@pytest.mark.asyncio +async def test_flushing_parquet_writer_resets_underlying_file(record_batch): + """Test flushing a writer resets underlying file.""" + flush_counter = 0 + + async def track_flushes(*args, **kwargs): + nonlocal flush_counter + flush_counter += 1 + + schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"] + writer = ParquetBatchExportWriter( + max_bytes=10000000, + flush_callable=track_flushes, + 
schema=record_batch.select(schema_columns).schema, + ) + + async with writer.open_temporary_file(): + await writer.write_record_batch(record_batch) + + assert writer.batch_export_file.tell() > 0 + assert writer.bytes_since_last_flush > 0 + assert writer.bytes_since_last_flush == writer.batch_export_file.bytes_since_last_reset + assert writer.records_since_last_flush == record_batch.num_rows + + await writer.flush(dt.datetime.now()) + + assert flush_counter == 1 + assert writer.batch_export_file.tell() == 0 + assert writer.bytes_since_last_flush == 0 + assert writer.bytes_since_last_flush == writer.batch_export_file.bytes_since_last_reset + assert writer.records_since_last_flush == 0 + + assert flush_counter == 2