test: Add writer classes tests and docstrings
tomasfarias committed Mar 21, 2024
1 parent d51400b commit 33f594f
Showing 2 changed files with 250 additions and 20 deletions.
69 changes: 49 additions & 20 deletions posthog/temporal/batch_exports/temporary_file.py
@@ -231,36 +231,38 @@ class BatchExportWriter(abc.ABC):
Actual writing calls are passed to the underlying `batch_export_file`.
Attributes:
- batch_export_file: The temporary file we are writing to.
- bytes_flush_threshold: Flush the temporary file with the provided `flush_callable`
+ _batch_export_file: The temporary file we are writing to.
+ max_bytes: Flush the temporary file with the provided `flush_callable`
upon reaching or surpassing this threshold. Keep in mind we write on a RecordBatch
per RecordBatch basis, which means the threshold will be surpassed by at most the
size of a RecordBatch before a flush occurs.
- flush_callable: A callback to flush the temporary file when `bytes_flush_treshold` is reached.
+ flush_callable: A callback to flush the temporary file when `max_bytes` is reached.
The temporary file will be reset after calling `flush_callable`.
file_kwargs: Optional keyword arguments passed when initializing `_batch_export_file`.
- last_inserted_at: Latest `_inserted_at` written. This attribute leaks some implementation
- details, as we are assuming assume `_inserted_at` is present, as it's added to all
- batch export queries.
records_total: The total number of records (not RecordBatches!) written.
records_since_last_flush: The number of records written since last flush.
+ last_inserted_at: Latest `_inserted_at` written. This attribute leaks some implementation
+ details, as we are making two assumptions about the RecordBatches being written:
+ * We assume RecordBatches are sorted on `_inserted_at`, which currently happens with
+ an `ORDER BY` clause.
+ * We assume `_inserted_at` is present, as it's added to all batch export queries.
bytes_total: The total number of bytes written.
bytes_since_last_flush: The number of bytes written since last flush.
"""

def __init__(
self,
flush_callable: FlushCallable,
max_bytes: int,
- file_kwargs: collections.abc.Mapping[str, typing.Any],
+ file_kwargs: collections.abc.Mapping[str, typing.Any] | None = None,
):
self.flush_callable = flush_callable
self.max_bytes = max_bytes
- self.file_kwargs = file_kwargs
+ self.file_kwargs: collections.abc.Mapping[str, typing.Any] = file_kwargs or {}

self._batch_export_file: BatchExportTemporaryFile | None = None
self.reset_writer_tracking()

def reset_writer_tracking(self):
"""Reset this writer's tracking state."""
self.last_inserted_at: dt.datetime | None = None
self.records_total = 0
self.records_since_last_flush = 0
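For orientation, the `flush_callable` described in the docstring above is an async callable that receives the temporary file plus the flush bookkeeping (the `FlushCallable` type alias itself is defined elsewhere in this module and is exercised by the tests further down). A minimal sketch assuming that signature; the `print_on_flush` name is purely illustrative:

    import datetime as dt

    from posthog.temporal.batch_exports.temporary_file import BatchExportTemporaryFile

    async def print_on_flush(
        batch_export_file: BatchExportTemporaryFile,
        records_since_last_flush: int,
        bytes_since_last_flush: int,
        last_inserted_at: dt.datetime,
        is_last: bool,
    ) -> None:
        # A real callable would ship batch_export_file.read() to the export
        # destination; this sketch only reports what a flush would send.
        print(f"flushing {records_since_last_flush} records ({bytes_since_last_flush} bytes), last _inserted_at={last_inserted_at}")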
@@ -274,7 +276,8 @@ async def open_temporary_file(self):
The underlying `BatchExportTemporaryFile` is only accessible within this context manager. This helps
us separate the lifetime of the underlying temporary file from the writer: The writer may still be
accessed even after the temporary file is closed, while on the other hand we ensure the file and all
- its data is flushed and not leaked outside the context. Any relevant tracking information
+ its data is flushed and not leaked outside the context. Any relevant tracking information is copied
+ to the writer.
"""
self.reset_writer_tracking()

@@ -297,6 +300,11 @@ async def open_temporary_file(self):

@property
def batch_export_file(self):
"""Property for underlying temporary file.
Raises:
ValueError: if attempting to access the temporary file before it has been opened.
"""
if self._batch_export_file is None:
raise ValueError("Batch export file is closed. Did you forget to call 'open_temporary_file'?")
return self._batch_export_file
@@ -311,16 +319,19 @@ def _write_record_batch(self, record_batch: pa.RecordBatch) -> None:
pass

def track_records_written(self, record_batch: pa.RecordBatch) -> None:
"""Update this writer's state with the number of records in `record_batch`."""
self.records_total += record_batch.num_rows
self.records_since_last_flush += record_batch.num_rows

def track_bytes_written(self, batch_export_file: BatchExportTemporaryFile) -> None:
"""Update this writer's state with the bytes in `batch_export_file`."""
self.bytes_total = batch_export_file.bytes_total
self.bytes_since_last_flush = batch_export_file.bytes_since_last_reset

async def write_record_batch(self, record_batch: pa.RecordBatch) -> None:
"""Issue a record batch write tracking progress and flushing if required."""
- last_inserted_at = record_batch.column("_inserted_at")[0].as_py()
+ record_batch = record_batch.sort_by("_inserted_at")
+ last_inserted_at = record_batch.column("_inserted_at")[-1].as_py()

column_names = record_batch.column_names
column_names.pop(column_names.index("_inserted_at"))
@@ -336,6 +347,8 @@ async def write_record_batch(self, record_batch: pa.RecordBatch) -> None:

async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None:
"""Call the provided `flush_callable` and reset underlying file."""
self.batch_export_file.seek(0)

await self.flush_callable(
self.batch_export_file,
self.records_since_last_flush,
@@ -350,7 +363,12 @@ async def flush(self, last_inserted_at: dt.datetime, is_last: bool = False) -> None:


class JSONLBatchExportWriter(BatchExportWriter):
"""A `BatchExportWriter` for JSONLines format."""
"""A `BatchExportWriter` for JSONLines format.
Attributes:
default: The default function to use to cast non-serializable Python objects to serializable objects.
By default, non-serializable objects will be cast to string via `str()`.
"""

def __init__(
self,
@@ -368,6 +386,7 @@ def __init__(
self.default = default

def write(self, content: bytes) -> int:
"""Write a single row of JSONL."""
n = self.batch_export_file.write(orjson.dumps(content, default=str) + b"\n")
return n

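The `default=str` fallback in `write` matters for values orjson cannot encode natively; a standalone illustration, using `decimal.Decimal` as an example of such a value:

    from decimal import Decimal

    import orjson

    # orjson raises TypeError for types it cannot serialize natively (e.g. Decimal);
    # default=str, the fallback documented above, casts them to strings instead.
    line = orjson.dumps({"event": "purchase", "revenue": Decimal("9.99")}, default=str) + b"\n"
    assert line == b'{"event":"purchase","revenue":"9.99"}\n'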
@@ -430,26 +449,37 @@ def _write_record_batch(self, record_batch: pa.RecordBatch) -> None:


class ParquetBatchExportWriter(BatchExportWriter):
"""A `BatchExportWriter` for Apache Parquet format."""
"""A `BatchExportWriter` for Apache Parquet format.
We utilize and wrap a `pyarrow.parquet.ParquetWriter` to do the actual writing. We default to their
defaults for most parameters; however this class could be extended with more attributes to pass along
to `pyarrow.parquet.ParquetWriter`.
See the pyarrow docs for more details on what parameters can the writer be configured with:
https://arrow.apache.org/docs/python/generated/pyarrow.parquet.ParquetWriter.html
In contrast to other writers, instead of us handling compression we let `pyarrow.parquet.ParquetWriter`
handle it, so `BatchExportTemporaryFile` is always initialized with `compression=None`.
Attributes:
schema: The schema used by the Parquet file. Should match the schema of written RecordBatches.
compression: Compression codec passed to underlying `pyarrow.parquet.ParquetWriter`.
"""

def __init__(
self,
max_bytes: int,
flush_callable: FlushCallable,
schema: pa.Schema,
version: str = "2.6",
compression: str | None = "snappy",
compression_level: int | None = None,
):
super().__init__(
max_bytes=max_bytes,
flush_callable=flush_callable,
file_kwargs={"compression": None}, # ParquetWriter handles compression
)
self.schema = schema
self.version = version
self.compression = compression
self.compression_level = compression_level

self._parquet_writer: pq.ParquetWriter | None = None

@@ -459,10 +489,8 @@ def parquet_writer(self) -> pq.ParquetWriter:
self._parquet_writer = pq.ParquetWriter(
self.batch_export_file,
schema=self.schema,
version=self.version,
# Compression *can* be `None`.
compression=self.compression,
compression_level=self.compression_level,
)
return self._parquet_writer

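As the class docstring notes, compression is delegated to `pyarrow.parquet.ParquetWriter` rather than applied to the temporary file itself; a standalone sketch of that delegation, independent of these writer classes:

    import io

    import pyarrow as pa
    import pyarrow.parquet as pq

    # ParquetWriter compresses column chunks itself, so the file object it writes
    # to stays uncompressed, mirroring file_kwargs={"compression": None} above.
    sink = io.BytesIO()
    batch = pa.RecordBatch.from_pydict({"event": pa.array(["test-event-0", "test-event-1"])})
    with pq.ParquetWriter(sink, schema=batch.schema, compression="snappy") as parquet_writer:
        parquet_writer.write_batch(batch)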
@@ -485,4 +513,5 @@ async def open_temporary_file(self):

def _write_record_batch(self, record_batch: pa.RecordBatch) -> None:
"""Write records to a temporary file as Parquet."""

self.parquet_writer.write_batch(record_batch.select(self.parquet_writer.schema.names))
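Putting the pieces together, a writer is driven inside `open_temporary_file`, with `write_record_batch` handling tracking and threshold flushes. A minimal sketch, reusing the hypothetical `print_on_flush` callable sketched earlier:

    import pyarrow as pa

    from posthog.temporal.batch_exports.temporary_file import JSONLBatchExportWriter

    async def export_batches(record_batches: list[pa.RecordBatch]) -> None:
        writer = JSONLBatchExportWriter(max_bytes=1024 * 1024, flush_callable=print_on_flush)

        async with writer.open_temporary_file():
            for record_batch in record_batches:
                # Tracks records and bytes, flushing via print_on_flush whenever max_bytes
                # is reached; remaining data is flushed before the context exits.
                await writer.write_record_batch(record_batch)

        print(f"wrote {writer.records_total} records, {writer.bytes_total} bytes in total")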
201 changes: 201 additions & 0 deletions posthog/temporal/tests/batch_exports/test_temporary_file.py
@@ -1,11 +1,17 @@
import csv
import datetime as dt
import io
import json

import pyarrow as pa
import pyarrow.parquet as pq
import pytest

from posthog.temporal.batch_exports.temporary_file import (
BatchExportTemporaryFile,
CSVBatchExportWriter,
JSONLBatchExportWriter,
ParquetBatchExportWriter,
json_dumps_bytes,
)

@@ -186,3 +192,198 @@ def test_batch_export_temporary_file_write_records_to_tsv(records):
assert be_file.bytes_since_last_reset == 0
assert be_file.records_total == len(records)
assert be_file.records_since_last_reset == 0


TEST_RECORD_BATCHES = [
pa.RecordBatch.from_pydict(
{
"event": pa.array(["test-event-0", "test-event-1", "test-event-2"]),
"properties": pa.array(['{"prop_0": 1, "prop_1": 2}', "{}", "null"]),
"_inserted_at": pa.array([0, 1, 2]),
}
)
]
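The test batch above uses plain integers for `_inserted_at` to keep assertions simple; a batch with timestamp values, presumably closer to what the batch export queries produce, could be built the same way (illustrative only):

    import datetime as dt

    import pyarrow as pa

    base = dt.datetime(2024, 3, 21, tzinfo=dt.timezone.utc)
    record_batch = pa.RecordBatch.from_pydict(
        {
            "event": pa.array(["test-event-0", "test-event-1"]),
            "properties": pa.array(['{"prop_0": 1}', "{}"]),
            "_inserted_at": pa.array([base, base + dt.timedelta(seconds=1)]),
        }
    )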


@pytest.mark.parametrize(
"record_batch",
TEST_RECORD_BATCHES,
)
@pytest.mark.asyncio
async def test_jsonl_writer_writes_record_batches(record_batch):
"""Test record batches are written as valid JSONL."""
in_memory_file_obj = io.BytesIO()
inserted_ats_seen = []

async def store_in_memory_on_flush(
batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last
):
in_memory_file_obj.write(batch_export_file.read())
inserted_ats_seen.append(last_inserted_at)

writer = JSONLBatchExportWriter(max_bytes=1, flush_callable=store_in_memory_on_flush)

record_batch = record_batch.sort_by("_inserted_at")
async with writer.open_temporary_file():
await writer.write_record_batch(record_batch)

lines = in_memory_file_obj.readlines()
for index, line in enumerate(lines):
written_jsonl = json.loads(line)

single_record_batch = record_batch.slice(offset=index, length=1)
expected_jsonl = single_record_batch.to_pylist()[0]

assert "_inserted_at" not in written_jsonl
assert written_jsonl == expected_jsonl

assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()]


@pytest.mark.parametrize(
"record_batch",
TEST_RECORD_BATCHES,
)
@pytest.mark.asyncio
async def test_csv_writer_writes_record_batches(record_batch):
"""Test record batches are written as valid CSV."""
in_memory_file_obj = io.StringIO()
inserted_ats_seen = []

async def store_in_memory_on_flush(
batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last
):
in_memory_file_obj.write(batch_export_file.read().decode("utf-8"))
inserted_ats_seen.append(last_inserted_at)

schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"]
writer = CSVBatchExportWriter(max_bytes=1, field_names=schema_columns, flush_callable=store_in_memory_on_flush)

record_batch = record_batch.sort_by("_inserted_at")
async with writer.open_temporary_file():
await writer.write_record_batch(record_batch)

reader = csv.reader(
in_memory_file_obj,
delimiter=",",
quotechar='"',
escapechar="\\",
quoting=csv.QUOTE_NONE,
)
for index, written_csv_row in enumerate(reader):
single_record_batch = record_batch.slice(offset=index, length=1)
expected_csv = single_record_batch.to_pylist()[0]

assert "_inserted_at" not in written_csv_row
assert written_csv_row == expected_csv

assert inserted_ats_seen == [record_batch.column("_inserted_at")[-1].as_py()]


@pytest.mark.parametrize(
"record_batch",
TEST_RECORD_BATCHES,
)
@pytest.mark.asyncio
async def test_parquet_writer_writes_record_batches(record_batch):
"""Test record batches are written as valid Parquet."""
in_memory_file_obj = io.BytesIO()
inserted_ats_seen = []

async def store_in_memory_on_flush(
batch_export_file, records_since_last_flush, bytes_since_last_flush, last_inserted_at, is_last
):
in_memory_file_obj.write(batch_export_file.read())
inserted_ats_seen.append(last_inserted_at)

schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"]

writer = ParquetBatchExportWriter(
max_bytes=1,
flush_callable=store_in_memory_on_flush,
schema=record_batch.select(schema_columns).schema,
)

record_batch = record_batch.sort_by("_inserted_at")
async with writer.open_temporary_file():
await writer.write_record_batch(record_batch)

written_parquet = pq.read_table(in_memory_file_obj)

for index, written_row_as_dict in enumerate(written_parquet.to_pylist()):
single_record_batch = record_batch.slice(offset=index, length=1)
expected_row_as_dict = single_record_batch.select(schema_columns).to_pylist()[0]

assert "_inserted_at" not in written_row_as_dict
assert written_row_as_dict == expected_row_as_dict

# NOTE: Parquet gets flushed twice due to the extra flush at the end for footer bytes, so our mock function
# will see this value twice.
assert inserted_ats_seen == [
record_batch.column("_inserted_at")[-1].as_py(),
record_batch.column("_inserted_at")[-1].as_py(),
]


@pytest.mark.parametrize(
"record_batch",
TEST_RECORD_BATCHES,
)
@pytest.mark.asyncio
async def test_writing_out_of_scope_of_temporary_file_raises(record_batch):
"""Test attempting a write out of temporary file scope raises a `ValueError`."""

async def do_nothing(*args, **kwargs):
pass

schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"]
writer = ParquetBatchExportWriter(
max_bytes=10,
flush_callable=do_nothing,
schema=record_batch.select(schema_columns).schema,
)

async with writer.open_temporary_file():
pass

with pytest.raises(ValueError, match="Batch export file is closed"):
await writer.write_record_batch(record_batch)


@pytest.mark.parametrize(
"record_batch",
TEST_RECORD_BATCHES,
)
@pytest.mark.asyncio
async def test_flushing_parquet_writer_resets_underlying_file(record_batch):
"""Test flushing a writer resets underlying file."""
flush_counter = 0

async def track_flushes(*args, **kwargs):
nonlocal flush_counter
flush_counter += 1

schema_columns = [column_name for column_name in record_batch.column_names if column_name != "_inserted_at"]
writer = ParquetBatchExportWriter(
max_bytes=10000000,
flush_callable=track_flushes,
schema=record_batch.select(schema_columns).schema,
)

async with writer.open_temporary_file():
await writer.write_record_batch(record_batch)

assert writer.batch_export_file.tell() > 0
assert writer.bytes_since_last_flush > 0
assert writer.bytes_since_last_flush == writer.batch_export_file.bytes_since_last_reset
assert writer.records_since_last_flush == record_batch.num_rows

await writer.flush(dt.datetime.now())

assert flush_counter == 1
assert writer.batch_export_file.tell() == 0
assert writer.bytes_since_last_flush == 0
assert writer.bytes_since_last_flush == writer.batch_export_file.bytes_since_last_reset
assert writer.records_since_last_flush == 0

assert flush_counter == 2
