Remove context processing options #160

Merged · 1 commit · Nov 11, 2024
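In short: the `FileDriverOptions` dataclass and the `ctxmode` flag are removed. File drivers now receive the options dict directly, and `BenchmarkRecord.compact()` always inlines the context as a single `"context"` entry. A minimal before/after sketch of the write call (record contents and file name are illustrative, not from this PR):

    from nnbench.reporter.file import FileIO
    from nnbench.types import BenchmarkRecord

    rec = BenchmarkRecord(
        context={"git.commit": "abc123"},  # illustrative context
        benchmarks=[{"name": "accuracy", "value": 0.98}],
    )
    f = FileIO()

    # Before this PR, the context handling had to be chosen per write:
    # f.write(rec, "record.json", ctxmode="inline", options={"indent": 2})

    # After this PR, options go straight to the serializer (json.dump here),
    # and the context is always inlined:
    f.write(rec, "record.json", options={"indent": 2})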
76 changes: 28 additions & 48 deletions src/nnbench/reporter/file.py
@@ -1,26 +1,16 @@
 import os
 import threading
 from collections.abc import Callable, Sequence
 from dataclasses import dataclass, field
 from pathlib import Path
 from typing import IO, Any, Literal, cast

 from nnbench.reporter.base import BenchmarkReporter
 from nnbench.types import BenchmarkRecord


-@dataclass(frozen=True)
-class FileDriverOptions:
-    options: dict[str, Any] = field(default_factory=dict)
-    """Options to pass to the underlying serialization API call, e.g. ``json.dump``."""
-    ctxmode: Literal["flatten", "inline", "omit"] = "inline"
-    """How to handle the context struct."""
+_Options = dict[str, Any]
 SerDe = tuple[
-    Callable[[Sequence[BenchmarkRecord], IO, FileDriverOptions], None],
-    Callable[[IO, FileDriverOptions], list[BenchmarkRecord]],
+    Callable[[Sequence[BenchmarkRecord], IO, dict[str, Any]], None],
+    Callable[[IO, dict[str, Any]], list[BenchmarkRecord]],
 ]


@@ -30,80 +20,79 @@ class FileDriverOptions:
 _compression_lock = threading.Lock()


-def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     try:
         import yaml
     except ImportError:
         raise ModuleNotFoundError("`pyyaml` is not installed")

     bms = []
     for r in records:
-        bms += r.compact(mode=fdoptions.ctxmode)
-    yaml.safe_dump(bms, fp, **fdoptions.options)
+        bms += r.compact()
+    yaml.safe_dump(bms, fp, **options)


-def yaml_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def yaml_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     try:
         import yaml
     except ImportError:
         raise ModuleNotFoundError("`pyyaml` is not installed")

     # TODO: Use expandmany()
     bms = yaml.safe_load(fp)
     return [BenchmarkRecord.expand(bms)]


-def json_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def json_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     import json

     bm = []
     for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    json.dump(bm, fp, **fdoptions.options)
+        bm += r.compact()
+    json.dump(bm, fp, **options)


-def json_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def json_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     import json

-    benchmarks: list[dict[str, Any]] = json.load(fp, **fdoptions.options)
+    benchmarks: list[dict[str, Any]] = json.load(fp, **options)
     return [BenchmarkRecord.expand(benchmarks)]


-def ndjson_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def ndjson_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     import json

     bm = []
     for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    fp.write("\n".join([json.dumps(b) for b in bm]))
+        bm += r.compact()
+    fp.write("\n".join([json.dumps(b, **options) for b in bm]))


-def ndjson_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def ndjson_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     import json

     benchmarks: list[dict[str, Any]]
-    benchmarks = [json.loads(line, **fdoptions.options) for line in fp]
+    benchmarks = [json.loads(line, **options) for line in fp]
     return [BenchmarkRecord.expand(benchmarks)]


-def csv_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def csv_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     import csv

     bm = []
     for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **fdoptions.options)
+        bm += r.compact()
+    writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **options)
     writer.writeheader()

     for b in bm:
         writer.writerow(b)


-def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def csv_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     import csv
     import json

-    reader = csv.DictReader(fp, **fdoptions.options)
+    reader = csv.DictReader(fp, **options)

     benchmarks: list[dict[str, Any]] = []
     # apparently csv.DictReader has no appropriate type hint for __next__,
@@ -122,22 +111,22 @@ def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
     return [BenchmarkRecord.expand(benchmarks)]


-def parquet_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def parquet_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     import pyarrow as pa
     import pyarrow.parquet as pq

     bm = []
     for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
+        bm += r.compact()

     table = pa.Table.from_pylist(bm)
-    pq.write_table(table, fp, **fdoptions.options)
+    pq.write_table(table, fp, **options)


-def parquet_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def parquet_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     import pyarrow.parquet as pq

-    table = pq.read_table(fp, **fdoptions.options)
+    table = pq.read_table(fp, **options)
     benchmarks: list[dict[str, Any]] = table.to_pylist()
     return [BenchmarkRecord.expand(benchmarks)]

@@ -292,11 +281,8 @@ def read_batched(
         else:
             fd = open(file, mode)

-        # dummy value, since the context mode is unused in read ops.
-        fdoptions = FileDriverOptions(ctxmode="omit", options=options or {})
-
         with fd as fp:
-            return de(fp, fdoptions)
+            return de(fp, options or {})

     def write(
         self,
@@ -305,7 +291,6 @@ def write(
         mode: str = "w",
         driver: str | None = None,
         compression: str | None = None,
-        ctxmode: Literal["flatten", "inline", "omit"] = "inline",
         options: dict[str, Any] | None = None,
     ) -> None:
"""Greedy version of ``FileIO.write_batched()``"""
Expand All @@ -315,7 +300,6 @@ def write(
mode=mode,
driver=driver,
compression=compression,
ctxmode=ctxmode,
options=options,
)

@@ -326,7 +310,6 @@ def write_batched(
         mode: str = "w",
         driver: str | None = None,
         compression: str | None = None,
-        ctxmode: Literal["flatten", "inline", "omit"] = "inline",
         options: dict[str, Any] | None = None,
     ) -> None:
         """
@@ -348,8 +331,6 @@ def write_batched(
         compression: str | None
             Compression engine to use. If None, the compression inferred from the given
             file path's extension will be used.
-        ctxmode: Literal["flatten", "inline", "omit"]
-            How to handle the benchmark context when writing the record data.
         options: dict[str, Any] | None
             Options to pass to the respective file driver implementation.

@@ -382,9 +363,8 @@ def write_batched(
         else:
             fd = open(file, mode)

-        fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
         with fd as fp:
-            ser(records, fp, fdoptions)
+            ser(records, fp, options or {})


 class FileReporter(FileIO, BenchmarkReporter):
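Since the drivers now forward the options dict verbatim, each format accepts whatever keyword arguments its backing library accepts. A rough sketch under that assumption, reusing `f` and `rec` from the sketch above (file names illustrative; parquet needs a binary mode):

    f.write(rec, "record.json", options={"indent": 2})      # json.dump kwargs
    f.write(rec, "record.csv", options={"delimiter": ";"})  # csv.DictWriter kwargs
    f.write(rec, "record.parquet", mode="wb",
            options={"compression": "snappy"})               # pyarrow.parquet.write_table kwargs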
31 changes: 8 additions & 23 deletions src/nnbench/types/benchmark.py
@@ -5,7 +5,7 @@
 from collections.abc import Callable, Mapping
 from dataclasses import dataclass, field
 from types import MappingProxyType
-from typing import Any, Literal
+from typing import Any

 if sys.version_info >= (3, 11):
     from typing import Self
@@ -32,32 +32,23 @@ class BenchmarkRecord:
     context: dict[str, Any]
     benchmarks: list[dict[str, Any]]

-    def compact(
-        self,
-        mode: Literal["flatten", "inline", "omit"] = "inline",
-        sep: str = ".",
-    ) -> list[dict[str, Any]]:
+    def compact(self, sep: str = ".") -> list[dict[str, Any]]:
         """
         Prepare the benchmark results, optionally inlining the context either as a
         nested dictionary or in flattened form.

         Parameters
         ----------
-        mode: Literal["flatten", "inline", "omit"]
-            How to handle the context. ``"omit"`` leaves out the context entirely, ``"inline"``
-            inserts it into the benchmark dictionary as a single entry named ``"context"``, and
-            ``"flatten"`` inserts the flattened context values into the dictionary.
         sep: str
-            The separator to use when flattening the context, i.e. when ``mode = "flatten"``.
+            Deprecated, unused.

         Returns
         -------
         list[dict[str, Any]]
             The updated list of benchmark records.
         """
-        if mode == "omit":
-            return self.benchmarks
-
+        # TODO: Allow keeping data as top-level struct?
+        # i.e. .compact(inline=False) -> { "context": {...}, "benchmarks": [...] }
         result = []

         for b in self.benchmarks:
@@ -86,18 +77,12 @@ def expand(cls, bms: list[dict[str, Any]]) -> Self:
"""
ctx: dict[str, Any] = {}
for b in bms:
# Safeguard if the context is not in the deser'd record,
# for example if the record came from a DB query.
if "context" in b:
ctx = b.pop("context")
elif "_contextkeys" in b:
ctxkeys = b.pop("_contextkeys")
for k in ctxkeys:
# This should never throw, save for data corruption.
ctx[k] = b.pop(k)
ctx = b.pop("context", {})
return cls(context=ctx, benchmarks=bms)

# TODO: Add an expandmany() API for returning a sequence of records for heterogeneous
# context data.


 @dataclass(frozen=True)
 class Benchmark:
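With the modes gone, `compact()` and `expand()` form a single inline roundtrip: `compact()` attaches the context to every benchmark dict under a `"context"` key, and `expand()` pops it back out, falling back to an empty context when the key is missing. A rough illustration with made-up values:

    from nnbench.types import BenchmarkRecord

    rec = BenchmarkRecord(
        context={"python": "3.11"},
        benchmarks=[{"name": "f1", "value": 0.9}],
    )

    flat = rec.compact()
    # -> [{"name": "f1", "value": 0.9, "context": {"python": "3.11"}}]

    rec2 = BenchmarkRecord.expand(flat)
    # -> BenchmarkRecord(context={"python": "3.11"},
    #                    benchmarks=[{"name": "f1", "value": 0.9}])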
12 changes: 4 additions & 8 deletions tests/test_fileio.py
@@ -1,6 +1,4 @@
-import itertools
 from pathlib import Path
-from typing import Literal

 import pytest

@@ -9,12 +7,10 @@


 @pytest.mark.parametrize(
-    "ext,ctxmode",
-    itertools.product(["yaml", "json", "ndjson", "csv", "parquet"], ["inline", "flatten"]),
+    "ext",
+    ["yaml", "json", "ndjson", "csv", "parquet"],
 )
-def test_fileio_writes_no_compression_inline(
-    tmp_path: Path, ext: str, ctxmode: Literal["inline", "flatten"]
-) -> None:
+def test_fileio_writes_no_compression_inline(tmp_path: Path, ext: str) -> None:
     """Tests data integrity for file IO roundtrips with both context modes."""
     f = FileIO()

@@ -24,7 +20,7 @@ def test_fileio_writes_no_compression_inline(tmp_path: Path, ext: str) -> None:
     )
     file = tmp_path / f"record.{ext}"
     writemode = "wb" if ext == "parquet" else "w"
-    f.write(rec, file, mode=writemode, ctxmode="inline")
+    f.write(rec, file, mode=writemode)
     readmode = "rb" if ext == "parquet" else "r"
     rec2 = f.read(file, mode=readmode)
     # Python stdlib csv coerces everything to string.