refactor: Support for multiple file formats in S3 batch exports #20979

Merged · 14 commits · Mar 22, 2024
1 change: 1 addition & 0 deletions posthog/batch_exports/service.py
@@ -90,6 +90,7 @@ class S3BatchExportInputs:
kms_key_id: str | None = None
batch_export_schema: BatchExportSchema | None = None
endpoint_url: str | None = None
file_format: str = "JSONLines"


@dataclass
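The new file_format input defaults to "JSONLines", so existing S3 destinations keep their current behaviour. Below is a minimal sketch of how an insert activity could dispatch on this value; the helper names and dispatch table are illustrative only, not the PR's actual implementation, and assume records arrive as plain dicts.

import json
import typing

# Illustrative only: a tiny dispatcher keyed on the new file_format string.
Writer = typing.Callable[[list[dict], typing.IO[bytes]], None]


def write_jsonl(records: list[dict], file: typing.IO[bytes]) -> None:
    """Write each record as one JSON line, stringifying non-JSON-native values."""
    for record in records:
        file.write(json.dumps(record, default=str).encode("utf-8") + b"\n")


SUPPORTED_WRITERS: dict[str, Writer] = {"JSONLines": write_jsonl}


def get_writer(file_format: str) -> Writer:
    """Look up a writer function for S3BatchExportInputs.file_format."""
    try:
        return SUPPORTED_WRITERS[file_format]
    except KeyError:
        raise ValueError(f"Unsupported file format: '{file_format}'") from None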
@@ -31,7 +31,7 @@
FROM events LEFT JOIN (
SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id
FROM person_static_cohort
WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [12]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id)
WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [11]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id)
WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0))
LIMIT 100
SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1
@@ -42,7 +42,7 @@
FROM events LEFT JOIN (
SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id
FROM static_cohort_people
WHERE in(cohort_id, [12])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id)
WHERE in(cohort_id, [11])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id)
WHERE and(1, equals(__in_cohort.matched, 1))
LIMIT 100
'''
@@ -55,7 +55,7 @@
FROM events LEFT JOIN (
SELECT person_static_cohort.person_id AS cohort_person_id, 1 AS matched, person_static_cohort.cohort_id AS cohort_id
FROM person_static_cohort
WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [13]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id)
WHERE and(equals(person_static_cohort.team_id, 420), in(person_static_cohort.cohort_id, [12]))) AS __in_cohort ON equals(__in_cohort.cohort_person_id, events.person_id)
WHERE and(equals(events.team_id, 420), 1, ifNull(equals(__in_cohort.matched, 1), 0))
LIMIT 100
SETTINGS readonly=2, max_execution_time=60, allow_experimental_object_type=1
@@ -66,7 +66,7 @@
FROM events LEFT JOIN (
SELECT person_id AS cohort_person_id, 1 AS matched, cohort_id
FROM static_cohort_people
WHERE in(cohort_id, [13])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id)
WHERE in(cohort_id, [12])) AS __in_cohort ON equals(__in_cohort.cohort_person_id, person_id)
WHERE and(1, equals(__in_cohort.matched, 1))
LIMIT 100
'''
201 changes: 0 additions & 201 deletions posthog/temporal/batch_exports/batch_exports.py
@@ -1,15 +1,10 @@
import collections.abc
import csv
import dataclasses
import datetime as dt
import gzip
import tempfile
import typing
import uuid
from string import Template

import brotli
import orjson
import pyarrow as pa
from asgiref.sync import sync_to_async
from django.conf import settings
@@ -286,202 +281,6 @@ def get_data_interval(interval: str, data_interval_end: str | None) -> tuple[dt.
return (data_interval_start_dt, data_interval_end_dt)


def json_dumps_bytes(d) -> bytes:
return orjson.dumps(d, default=str)


class BatchExportTemporaryFile:
"""A TemporaryFile used to as an intermediate step while exporting data.

This class does not implement the file-like interface but rather passes any calls
to the underlying tempfile.NamedTemporaryFile. We do override 'write' methods
to allow tracking bytes and records.
"""

def __init__(
self,
mode: str = "w+b",
buffering=-1,
compression: str | None = None,
encoding: str | None = None,
newline: str | None = None,
suffix: str | None = None,
prefix: str | None = None,
dir: str | None = None,
*,
errors: str | None = None,
):
self._file = tempfile.NamedTemporaryFile(
mode=mode,
encoding=encoding,
newline=newline,
buffering=buffering,
suffix=suffix,
prefix=prefix,
dir=dir,
errors=errors,
)
self.compression = compression
self.bytes_total = 0
self.records_total = 0
self.bytes_since_last_reset = 0
self.records_since_last_reset = 0
self._brotli_compressor = None

def __getattr__(self, name):
"""Pass get attr to underlying tempfile.NamedTemporaryFile."""
return self._file.__getattr__(name)

def __enter__(self):
"""Context-manager protocol enter method."""
self._file.__enter__()
return self

def __exit__(self, exc, value, tb):
"""Context-manager protocol exit method."""
return self._file.__exit__(exc, value, tb)

def __iter__(self):
yield from self._file

@property
def brotli_compressor(self):
if self._brotli_compressor is None:
self._brotli_compressor = brotli.Compressor()
return self._brotli_compressor

def compress(self, content: bytes | str) -> bytes:
if isinstance(content, str):
encoded = content.encode("utf-8")
else:
encoded = content

match self.compression:
case "gzip":
return gzip.compress(encoded)
case "brotli":
self.brotli_compressor.process(encoded)
return self.brotli_compressor.flush()
case None:
return encoded
case _:
raise ValueError(f"Unsupported compression: '{self.compression}'")

def write(self, content: bytes | str):
"""Write bytes to underlying file keeping track of how many bytes were written."""
compressed_content = self.compress(content)

if "b" in self.mode:
result = self._file.write(compressed_content)
else:
result = self._file.write(compressed_content.decode("utf-8"))

self.bytes_total += result
self.bytes_since_last_reset += result

return result

def write_record_as_bytes(self, record: bytes):
result = self.write(record)

self.records_total += 1
self.records_since_last_reset += 1

return result

def write_records_to_jsonl(self, records):
"""Write records to a temporary file as JSONL."""
if len(records) == 1:
jsonl_dump = orjson.dumps(records[0], option=orjson.OPT_APPEND_NEWLINE, default=str)
else:
jsonl_dump = b"\n".join(map(json_dumps_bytes, records))

result = self.write(jsonl_dump)

self.records_total += len(records)
self.records_since_last_reset += len(records)

return result

def write_records_to_csv(
self,
records,
fieldnames: None | collections.abc.Sequence[str] = None,
extrasaction: typing.Literal["raise", "ignore"] = "ignore",
delimiter: str = ",",
quotechar: str = '"',
escapechar: str | None = "\\",
lineterminator: str = "\n",
quoting=csv.QUOTE_NONE,
):
"""Write records to a temporary file as CSV."""
if len(records) == 0:
return

if fieldnames is None:
fieldnames = list(records[0].keys())

writer = csv.DictWriter(
self,
fieldnames=fieldnames,
extrasaction=extrasaction,
delimiter=delimiter,
quotechar=quotechar,
escapechar=escapechar,
quoting=quoting,
lineterminator=lineterminator,
)
writer.writerows(records)

self.records_total += len(records)
self.records_since_last_reset += len(records)

def write_records_to_tsv(
self,
records,
fieldnames: None | list[str] = None,
extrasaction: typing.Literal["raise", "ignore"] = "ignore",
quotechar: str = '"',
escapechar: str | None = "\\",
lineterminator: str = "\n",
quoting=csv.QUOTE_NONE,
):
"""Write records to a temporary file as TSV."""
return self.write_records_to_csv(
records,
fieldnames=fieldnames,
extrasaction=extrasaction,
delimiter="\t",
quotechar=quotechar,
escapechar=escapechar,
quoting=quoting,
lineterminator=lineterminator,
)

def rewind(self):
"""Rewind the file before reading it."""
if self.compression == "brotli":
result = self._file.write(self.brotli_compressor.finish())

self.bytes_total += result
self.bytes_since_last_reset += result

self._brotli_compressor = None

self._file.seek(0)

def reset(self):
"""Reset underlying file by truncating it.

Also resets the tracker attributes for bytes and records since last reset.
"""
self._file.seek(0)
self._file.truncate()

self.bytes_since_last_reset = 0
self.records_since_last_reset = 0


@dataclasses.dataclass
class CreateBatchExportRunInputs:
"""Inputs to the create_export_run activity.
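Note that BatchExportTemporaryFile and its helpers are moved rather than deleted: the import changes in the destination modules below pull them from posthog/temporal/batch_exports/temporary_file.py. A minimal usage sketch, assuming the class keeps the interface shown in the deleted lines above (the record content is made up):

from posthog.temporal.batch_exports.temporary_file import BatchExportTemporaryFile

records = [{"event": "$pageview", "distinct_id": "some-user"}]  # illustrative records

with BatchExportTemporaryFile(compression="brotli") as local_file:
    local_file.write_records_to_jsonl(records)
    # rewind() flushes the remaining brotli output and seeks back to the start.
    local_file.rewind()
    payload = local_file.read()  # compressed bytes, ready to upload
    assert local_file.records_since_last_reset == len(records)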
4 changes: 3 additions & 1 deletion posthog/temporal/batch_exports/bigquery_batch_export.py
@@ -15,7 +15,6 @@
from posthog.batch_exports.service import BatchExportField, BatchExportSchema, BigQueryBatchExportInputs
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.batch_exports.batch_exports import (
BatchExportTemporaryFile,
CreateBatchExportRunInputs,
UpdateBatchExportRunStatusInputs,
create_export_run,
@@ -29,6 +28,9 @@
get_bytes_exported_metric,
get_rows_exported_metric,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
)
from posthog.temporal.batch_exports.utils import peek_first_and_rewind
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.logger import bind_temporal_worker_logger
6 changes: 4 additions & 2 deletions posthog/temporal/batch_exports/http_batch_export.py
@@ -13,20 +13,22 @@
from posthog.models import BatchExportRun
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.batch_exports.batch_exports import (
BatchExportTemporaryFile,
CreateBatchExportRunInputs,
UpdateBatchExportRunStatusInputs,
create_export_run,
execute_batch_export_insert_activity,
get_data_interval,
get_rows_count,
iter_records,
json_dumps_bytes,
)
from posthog.temporal.batch_exports.metrics import (
get_bytes_exported_metric,
get_rows_exported_metric,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
json_dumps_bytes,
)
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.logger import bind_temporal_worker_logger

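json_dumps_bytes likewise now lives in the temporary_file module. It is the thin orjson wrapper shown in the deletions above, with default=str so values orjson cannot serialize natively fall back to their string form; a quick illustration (the Decimal field is only an example of such a value):

from decimal import Decimal

from posthog.temporal.batch_exports.temporary_file import json_dumps_bytes

# Decimal is not natively serializable by orjson, so default=str turns it into "9.99".
payload = json_dumps_bytes({"event": "purchase", "revenue": Decimal("9.99")})
assert payload == b'{"event":"purchase","revenue":"9.99"}'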
4 changes: 3 additions & 1 deletion posthog/temporal/batch_exports/postgres_batch_export.py
@@ -17,7 +17,6 @@
from posthog.batch_exports.service import BatchExportField, BatchExportSchema, PostgresBatchExportInputs
from posthog.temporal.batch_exports.base import PostHogWorkflow
from posthog.temporal.batch_exports.batch_exports import (
BatchExportTemporaryFile,
CreateBatchExportRunInputs,
UpdateBatchExportRunStatusInputs,
create_export_run,
@@ -31,6 +30,9 @@
get_bytes_exported_metric,
get_rows_exported_metric,
)
from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
)
from posthog.temporal.batch_exports.utils import peek_first_and_rewind
from posthog.temporal.common.clickhouse import get_client
from posthog.temporal.common.logger import bind_temporal_worker_logger