refactor: Support for multiple file formats in S3 batch exports (#20979)
* refactor: Support for multiple file formats in batch exports

* refactor: Prefer composition over inheritance

* refactor: More clearly separate writer from temporary file

We are now more explicit about the context in which the batch export
temporary file is alive. The writer outlives this context, so callers
can use it, for example, to check how many records were written (a
minimal sketch of this separation follows the changed-files summary
below).

* test: More parquet testing

* Update query snapshots

* fix: Typing

* refactor: Move temporary file to new module

* test: Add writer classes tests and docstrings

* feat: Add new type aliases and docstrings for FlushCallable

* refactor: Get rid of ensure close method

* fix: Use proper 'none' compression

* refactor: Cover all possible file formats with FILE_FORMAT_EXTENSIONS.keys()

* test: Also check if bucket name is set to use S3

* feat: Typing and docstring for get_batch_export_writer

---------

Co-authored-by: github-actions <41898282+github-actions[bot]@users.noreply.github.com>
tomasfarias and github-actions[bot] authored Mar 22, 2024
1 parent d5cc9e4 commit 7c0258c
Showing 12 changed files with 1,238 additions and 451 deletions.
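To illustrate the writer/temporary-file separation described in the commit message, here is a minimal synchronous sketch. It is not the code introduced by this commit: the names SketchWriter, open_temporary_file, and write_record, and the flush-callback signature are illustrative assumptions; only BatchExportTemporaryFile and its new module path are taken from the diffs below, and the class is assumed to keep the interface shown in the deleted code.

import contextlib
import typing

from posthog.temporal.batch_exports.temporary_file import BatchExportTemporaryFile


class SketchWriter:
    """Compose a temporary file rather than inherit from it (sketch only).

    The writer survives the temporary file's context, so counters such as
    records_total stay readable after the file has been closed.
    """

    def __init__(self, flush: typing.Callable[[BatchExportTemporaryFile, int], None]):
        self.flush = flush
        self.records_total = 0
        self._temp_file: BatchExportTemporaryFile | None = None

    @contextlib.contextmanager
    def open_temporary_file(self) -> typing.Iterator[None]:
        """Keep the temporary file alive only within this block."""
        with BatchExportTemporaryFile() as temp_file:
            self._temp_file = temp_file
            try:
                yield
            finally:
                # Flush whatever was buffered before the file is discarded.
                self.flush(temp_file, self.records_total)
                self._temp_file = None

    def write_record(self, record: bytes) -> None:
        """Write one record and track the count on the writer, not on the file."""
        assert self._temp_file is not None, "call inside open_temporary_file()"
        self._temp_file.write(record)
        self.records_total += 1


writer = SketchWriter(flush=lambda file, count: None)
with writer.open_temporary_file():
    writer.write_record(b'{"event": "pageview"}\n')

print(writer.records_total)  # the count outlives the temporary file context

Because the writer owns the file by composition instead of inheriting from it, record counters remain available to callers after the file's context has closed, which matches the intent stated in the commit message.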
1 change: 1 addition & 0 deletions posthog/batch_exports/service.py
@@ -90,6 +90,7 @@ class S3BatchExportInputs:
kms_key_id: str | None = None
batch_export_schema: BatchExportSchema | None = None
endpoint_url: str | None = None
file_format: str = "JSONLines"


@dataclass
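The new file_format field defaults to "JSONLines". The commit message also mentions a FILE_FORMAT_EXTENSIONS mapping that covers all supported formats; the snippet below is an illustrative sketch of that kind of lookup, assuming "Parquet" and "JSONLines" as keys and a gzip suffix rule, not code copied from this commit.

# Illustrative sketch only: the real FILE_FORMAT_EXTENSIONS mapping and any
# extension/compression handling live in the S3 batch export code, not here.
FILE_FORMAT_EXTENSIONS = {
    "Parquet": "parquet",
    "JSONLines": "jsonl",
}


def get_file_extension(file_format: str, compression: str | None = None) -> str:
    """Map a configured file format (and optional compression) to a key suffix."""
    try:
        extension = FILE_FORMAT_EXTENSIONS[file_format]
    except KeyError:
        supported = ", ".join(FILE_FORMAT_EXTENSIONS.keys())
        raise ValueError(f"Unsupported file format '{file_format}', expected one of: {supported}")

    if compression == "gzip":
        return f"{extension}.gz"
    return extension


print(get_file_extension("JSONLines", compression="gzip"))  # jsonl.gz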
@@ -31,7 +31,7 @@
FROM events LEFT JOIN (
SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id
FROM person_static_cohort
WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [12]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id)
WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [11]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id)
WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0))
LIMIT 100
SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1
@@ -42,7 +42,7 @@
FROM events LEFT JOIN (
SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id
FROM static_cohort_people
WHERE in(cohort_id, [12])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id)
WHERE in(cohort_id, [11])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id)
WHERE and(1, equals(__in_cohort.matched, 1))
LIMIT 100
'''
Expand All @@ -55,7 +55,7 @@
FROM events LEFT JOIN (
SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id
FROM person_static_cohort
WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [13]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id)
WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [12]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id)
WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0))
LIMIT 100
SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1
@@ -66,7 +66,7 @@
FROM events LEFT JOIN (
SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id
FROM static_cohort_people
WHERE in(cohort_id, [13])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id)
WHERE in(cohort_id, [12])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id)
WHERE and(1, equals(__in_cohort.matched, 1))
LIMIT 100
'''
201 changes: 0 additions & 201 deletions posthog/temporal/batch_exports/batch_exports.py
@@ -1,15 +1,10 @@
import collections.abc
import csv
import dataclasses
import datetime as dt
import gzip
import tempfile
import typing
import uuid
from string import Template

import brotli
import orjson
import pyarrow as pa
from asgiref.sync import sync_to_async
from django.conf import settings
@@ -286,202 +281,6 @@ def get_data_interval(interval: str, data_interval_end: str | None) -> tuple[dt.
return (data_interval_start_dt, data_interval_end_dt)


def json_dumps_bytes(d) -> bytes:
return orjson.dumps(d, default=str)


class BatchExportTemporaryFile:
"""A TemporaryFile used to as an intermediate step while exporting data.
This class does not implement the file-like interface but rather passes any calls
to the underlying tempfile.NamedTemporaryFile. We do override 'write' methods
to allow tracking bytes and records.
"""

def __init__(
self,
mode: str = "w+b",
buffering=-1,
compression: str | None = None,
encoding: str | None = None,
newline: str | None = None,
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
*,
errors: str | None = None,
):
self._file = tempfile.NamedTemporaryFile(
mode=mode,
encoding=encoding,
newline=newline,
buffering=buffering,
suffix=suffix,
prefix=prefix,
dir=dir,
errors=errors,
)
self.compression = compression
self.bytes_total = 0
self.records_total = 0
self.bytes_since_last_reset = 0
self.records_since_last_reset = 0
self._brotli_compressor = None

def __getattr__(self, name):
"""Pass get attr to underlying tempfile.NamedTemporaryFile."""
return self._file.__getattr__(name)

def __enter__(self):
"""Context-manager protocol enter method."""
self._file.__enter__()
return self

def __exit__(self, exc, value, tb):
"""Context-manager protocol exit method."""
return self._file.__exit__(exc, value, tb)

def __iter__(self):
yield from self._file

@property
def brotli_compressor(self):
if self._brotli_compressor is None:
self._brotli_compressor = brotli.Compressor()
return self._brotli_compressor

def compress(self, content: bytes | str) -> bytes:
if isinstance(content, str):
encoded = content.encode("utf-8")
else:
encoded = content

match self.compression:
case "gzip":
return gzip.compress(encoded)
case "brotli":
self.brotli_compressor.process(encoded)
return self.brotli_compressor.flush()
case None:
return encoded
case _:
raise ValueError(f"Unsupported compression: '{self.compression}'")

def write(self, content: bytes | str):
"""Write bytes to underlying file keeping track of how many bytes were written."""
compressed_content = self.compress(content)

if "b" in self.mode:
result = self._file.write(compressed_content)
else:
result = self._file.write(compressed_content.decode("utf-8"))

self.bytes_total += result
self.bytes_since_last_reset += result

return result

def write_record_as_bytes(self, record: bytes):
result = self.write(record)

self.records_total += 1
self.records_since_last_reset += 1

return result

def write_records_to_jsonl(self, records):
"""Write records to a temporary file as JSONL."""
if len(records) == 1:
jsonl_dump = orjson.dumps(records[0], option=orjson.OPT_APPEND_NEWLINE, default=str)
else:
jsonl_dump = b"\n".join(map(json_dumps_bytes, records))

result = self.write(jsonl_dump)

self.records_total += len(records)
self.records_since_last_reset += len(records)

return result

def write_records_to_csv(
self,
records,
fieldnames: None | collections.abc.Sequence[str] = None,
extrasaction: typing.Literal["raise", "ignore"] = "ignore",
delimiter: str = ",",
quotechar: str = '"',
escapechar: str | None = "\\",
lineterminator: str = "\n",
quoting=csv.QUOTE_NONE,
):
"""Write records to a temporary file as CSV."""
if len(records) == 0:
return

if fieldnames is None:
fieldnames = list(records[0].keys())

writer = csv.DictWriter(
self,
fieldnames=fieldnames,
extrasaction=extrasaction,
delimiter=delimiter,
quotechar=quotechar,
escapechar=escapechar,
quoting=quoting,
lineterminator=lineterminator,
)
writer.writerows(records)

self.records_total += len(records)
self.records_since_last_reset += len(records)

def write_records_to_tsv(
self,
records,
fieldnames: None | list[str] = None,
extrasaction: typing.Literal["raise", "ignore"] = "ignore",
quotechar: str = '"',
escapechar: str | None = "\\",
lineterminator: str = "\n",
quoting=csv.QUOTE_NONE,
):
"""Write records to a temporary file as TSV."""
return self.write_records_to_csv(
records,
fieldnames=fieldnames,
extrasaction=extrasaction,
delimiter="\t",
quotechar=quotechar,
escapechar=escapechar,
quoting=quoting,
lineterminator=lineterminator,
)

def rewind(self):
"""Rewind the file before reading it."""
if self.compression == "brotli":
result = self._file.write(self.brotli_compressor.finish())

self.bytes_total += result
self.bytes_since_last_reset += result

self._brotli_compressor = None

self._file.seek(0)

def reset(self):
"""Reset underlying file by truncating it.
Also resets the tracker attributes for bytes and records since last reset.
"""
self._file.seek(0)
self._file.truncate()

self.bytes_since_last_reset = 0
self.records_since_last_reset = 0


@dataclasses.dataclass
class CreateBatchExportRunInputs:
"""Inputs to the create_export_run activity.
4 changes: 3 additions & 1 deletion posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -15,7 +15,6 @@
from posthog.batch_exports.service import BatchExportField, BatchExportSchema, BigQueryBatchExportInputs
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.batch_exports.batch_exports import (
BatchExportTemporaryFile,
CreateBatchExportRunInputs,
UpdateBatchExportRunStatusInputs,
create_export_run,
@@ -29,6 +28,9 @@
get_bytes_exported_metric,
get_rows_exported_metric,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
)
from posthog.temporal.batch_exports.utils import peek_first_and_rewind
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.logger import bind_temporal_worker_logger
6 changes: 4 additions & 2 deletions posthog/temporal/batch_exports/http_batch_export.py
@@ -13,20 +13,22 @@
from posthog.models import BatchExportRun
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.batch_exports.batch_exports import (
BatchExportTemporaryFile,
CreateBatchExportRunInputs,
UpdateBatchExportRunStatusInputs,
create_export_run,
execute_batch_export_insert_activity,
get_data_interval,
get_rows_count,
iter_records,
json_dumps_bytes,
)
from posthog.temporal.batch_exports.metrics import (
get_bytes_exported_metric,
get_rows_exported_metric,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
json_dumps_bytes,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.logger import bind_temporal_worker_logger

4 changes: 3 additions & 1 deletion posthog/temporal/batch_exports/postgres_batch_export.py
@@ -17,7 +17,6 @@
from posthog.batch_exports.service import BatchExportField, BatchExportSchema, PostgresBatchExportInputs
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.batch_exports.batch_exports import (
BatchExportTemporaryFile,
CreateBatchExportRunInputs,
UpdateBatchExportRunStatusInputs,
create_export_run,
@@ -31,6 +30,9 @@
get_bytes_exported_metric,
get_rows_exported_metric,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
)
from posthog.temporal.batch_exports.utils import peek_first_and_rewind
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.logger import bind_temporal_worker_logger