refactor: Support for multiple file formats in batch exports
tomasfarias committed Mar 18, 2024
1 parent e85168c commit 148e0da
Showing 4 changed files with 396 additions and 53 deletions.
1 change: 1 addition & 0 deletions posthog/batch_exports/service.py
@@ -87,6 +87,7 @@ class S3BatchExportInputs:
kms_key_id: str | None = None
batch_export_schema: BatchExportSchema | None = None
endpoint_url: str | None = None
file_format: str = "JSONLines"


@dataclass
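
Because the new field carries a default, existing export configurations keep producing JSONL. A minimal usage sketch, assuming a hypothetical stored_inputs payload (the dataclass fields above line 87 of this file are not shown in the diff):

# Hedged sketch: older serialized inputs omit file_format and fall back to JSONL.
inputs = S3BatchExportInputs(**stored_inputs)  # file_format == "JSONLines"
inputs = S3BatchExportInputs(**stored_inputs, file_format="Parquet")  # opt in to Parquet
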
164 changes: 164 additions & 0 deletions posthog/temporal/batch_exports/batch_exports.py
@@ -11,6 +11,7 @@
import brotli
import orjson
import pyarrow as pa
import pyarrow.parquet as pq
from asgiref.sync import sync_to_async
from django.conf import settings
from temporalio import activity, exceptions, workflow
@@ -482,6 +483,169 @@ def reset(self):
self.records_since_last_reset = 0


FlushCallable = collections.abc.Callable[
[BatchExportTemporaryFile, int, int, dt.datetime, bool], collections.abc.Awaitable[None]
]


class BatchExportWriter(typing.Protocol):
    """Protocol for writers that stream PyArrow record batches into a
    BatchExportTemporaryFile, calling flush_callable whenever the buffered
    bytes reach max_bytes."""

    batch_export_file: BatchExportTemporaryFile
    flush_callable: FlushCallable
records_total: int = 0
records_since_last_flush: int = 0
max_bytes: int = 0
last_inserted_at: dt.datetime | None = None

    async def __aenter__(self):
        """Async context-manager enter: open the underlying temporary file."""
        self.batch_export_file.__enter__()
        return self

    async def __aexit__(self, exc, value, tb):
        """Async context-manager exit: flush any pending records, then close the file."""
        if self.last_inserted_at is not None and self.records_since_last_flush > 0:
            await self.flush(self.last_inserted_at, is_last=True)
        return self.batch_export_file.__exit__(exc, value, tb)

    def _write_record_batch(self, record_batch: pa.RecordBatch) -> None:
        """Write a record batch to the underlying file; each subclass implements its format."""
        ...

    async def write_record_batch(self, record_batch: pa.RecordBatch) -> None:
        """Write a record batch, dropping _inserted_at from the output and flushing
        once max_bytes have been buffered."""
        last_inserted_at = record_batch.column("_inserted_at")[0].as_py()

column_names = record_batch.column_names
column_names.pop(column_names.index("_inserted_at"))

self._write_record_batch(record_batch.select(column_names))

self.last_inserted_at = last_inserted_at

if self.bytes_since_last_flush >= self.max_bytes:
await self.flush(last_inserted_at)

    async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None:
        """Hand the buffered file to flush_callable, then reset the file and counters."""
await self.flush_callable(
self.batch_export_file,
self.records_since_last_flush,
self.bytes_since_last_flush,
last_inserted_at,
is_last,
)
self.batch_export_file.reset()
self.records_since_last_flush = 0

    @property
    def bytes_since_last_flush(self) -> int:
        """Bytes written to the temporary file since the last flush."""
        return self.batch_export_file.bytes_since_last_reset
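
FlushCallable fixes the signature a destination must supply. A hedged sketch of one such callback, where uploader stands in for destination-specific code that is not part of the hunks shown here:

async def flush_to_destination(
    batch_export_file: BatchExportTemporaryFile,
    records_since_last_flush: int,
    bytes_since_last_flush: int,
    last_inserted_at: dt.datetime,
    is_last: bool,
) -> None:
    """Hypothetical callback: upload the buffered part; the writer resets the file afterwards."""
    await uploader.upload_part(batch_export_file)  # `uploader` is assumed, not in this diff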


class JSONLBatchExportWriter(BatchExportWriter):
    """Write record batches as JSON lines, one serialized record per line."""

    def __init__(
        self,
        max_bytes: int,
        flush_callable: FlushCallable,
        compression: str | None = None,
        default: typing.Callable = str,
    ):
self.batch_export_file = BatchExportTemporaryFile(compression=compression)
self.flush_callable = flush_callable

self.default = default

self.max_bytes = max_bytes
self.last_inserted_at = None

    def write(self, content: dict) -> int:
        """Serialize a single record to JSONL and write it to the temporary file."""
        n = self.batch_export_file.write(orjson.dumps(content, default=self.default) + b"\n")
        return n

def _write_record_batch(self, record_batch: pa.RecordBatch) -> None:
"""Write records to a temporary file as JSONL."""
for record in record_batch.to_pylist():
self.write(record)

self.records_total += 1
self.records_since_last_flush += 1
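
A hedged sketch of the writer lifecycle, reusing the flush_to_destination sketch above and assuming a generate_record_batches async producer (the activity code that drives writers is not among these hunks):

writer = JSONLBatchExportWriter(max_bytes=50 * 1024 * 1024, flush_callable=flush_to_destination)
async with writer:
    async for record_batch in generate_record_batches():  # assumed async producer
        await writer.write_record_batch(record_batch)
# Any records still buffered at exit were flushed with is_last=True.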


class CSVBatchExportWriter(csv.DictWriter, BatchExportWriter):
    """Write record batches as CSV rows via csv.DictWriter."""
def __init__(
self,
max_bytes: int,
flush_callable: FlushCallable,
field_names: collections.abc.Sequence[str],
extras_action: typing.Literal["raise", "ignore"] = "ignore",
delimiter: str = ",",
quote_char: str = '"',
escape_char: str | None = "\\",
line_terminator: str = "\n",
quoting=csv.QUOTE_NONE,
compression: str | None = None,
):
self.batch_export_file = BatchExportTemporaryFile(compression=compression)
self.flush_callable = flush_callable

super().__init__(
self.batch_export_file,
fieldnames=field_names,
extrasaction=extras_action,
delimiter=delimiter,
quotechar=quote_char,
escapechar=escape_char,
quoting=quoting,
lineterminator=line_terminator,
)

self.max_bytes = max_bytes
self.last_inserted_at = None

def _write_record_batch(self, record_batch: pa.RecordBatch) -> None:
"""Write records to a temporary file as JSONL."""
self.writerows(record_batch.to_pylist())
self.records_total += record_batch.num_rows
self.records_since_last_flush += record_batch.num_rows


class ParquetBatchExportWriter(pq.ParquetWriter, BatchExportWriter):
    """Write record batches as Parquet, delegating compression to the Parquet writer itself."""
def __init__(
self,
max_bytes: int,
flush_callable: FlushCallable,
schema: pa.Schema,
version: str = "2.6",
use_dictionary: bool = True,
compression: str | None = "snappy",
compression_level: int | None = None,
):
self.batch_export_file = BatchExportTemporaryFile(compression=None) # Handle compression in ParquetWriter
self.flush_callable = flush_callable

super().__init__(
self.batch_export_file,
schema=schema,
version=version,
use_dictionary=use_dictionary,
compression=compression,
compression_level=compression_level,
)

self.max_bytes = max_bytes
self.last_inserted_at = None

    async def __aexit__(self, exc, value, tb):
        """Close the underlying Parquet writer so the footer bytes are written before the final flush."""
        self.writer.close()  # Close only the Parquet writer; the temporary file stays open for the flush
        self.is_open = False

await super().__aexit__(exc, value, tb)

def _write_record_batch(self, record_batch: pa.RecordBatch) -> None:
"""Write records to a temporary file as JSONL."""
self.write_batch(record_batch.select(self.schema.names))
self.records_total += record_batch.num_rows
self.records_since_last_flush += record_batch.num_rows
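
Together with the new file_format input, a destination can now select a writer per format. An illustrative dispatch, under the assumption that the real call site (not shown in these hunks) looks something like this:

def get_batch_export_writer(file_format, max_bytes, flush_callable, schema=None, compression=None):
    # Illustrative only; the actual selection code may differ.
    if file_format == "Parquet":
        return ParquetBatchExportWriter(max_bytes=max_bytes, flush_callable=flush_callable, schema=schema)
    if file_format == "JSONLines":
        return JSONLBatchExportWriter(max_bytes=max_bytes, flush_callable=flush_callable, compression=compression)
    raise ValueError(f"Unsupported file_format: {file_format}")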


@dataclasses.dataclass
class CreateBatchExportRunInputs:
"""Inputs to the create_export_run activity.