Remove context processing options
The context struct can be passed directly in all cases we currently
consider, so we no longer expose the option.
nicholasjng committed Nov 11, 2024
1 parent 3b27774 commit d51e11f
Showing 3 changed files with 40 additions and 79 deletions.
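
For orientation, here is a minimal sketch of the API change from the caller's side (the record values are made up; `FileIO.write` and its `options` passthrough are taken from the diffs below):

from nnbench.reporter.file import FileIO
from nnbench.types import BenchmarkRecord

f = FileIO()
rec = BenchmarkRecord(
    context={"git": "d51e11f"},
    benchmarks=[{"name": "accuracy", "value": 0.98}],
)

# Before this commit, callers chose a context serialization mode:
#   f.write(rec, "record.json", ctxmode="inline")
# Now the context always travels inline with the benchmarks, and `options`
# is forwarded to the underlying serializer (json.dump for JSON files):
f.write(rec, "record.json", options={"indent": 2})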
76 changes: 28 additions & 48 deletions src/nnbench/reporter/file.py
@@ -1,26 +1,16 @@
import os
import threading
from collections.abc import Callable, Sequence
from dataclasses import dataclass, field
from pathlib import Path
from typing import IO, Any, Literal, cast

from nnbench.reporter.base import BenchmarkReporter
from nnbench.types import BenchmarkRecord


-@dataclass(frozen=True)
-class FileDriverOptions:
-    options: dict[str, Any] = field(default_factory=dict)
-    """Options to pass to the underlying serialization API call, e.g. ``json.dump``."""
-    ctxmode: Literal["flatten", "inline", "omit"] = "inline"
-    """How to handle the context struct."""


+_Options = dict[str, Any]
SerDe = tuple[
-    Callable[[Sequence[BenchmarkRecord], IO, FileDriverOptions], None],
-    Callable[[IO, FileDriverOptions], list[BenchmarkRecord]],
+    Callable[[Sequence[BenchmarkRecord], IO, dict[str, Any]], None],
+    Callable[[IO, dict[str, Any]], list[BenchmarkRecord]],
]


@@ -30,80 +20,79 @@ class FileDriverOptions:
_compression_lock = threading.Lock()


-def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
    try:
        import yaml
    except ImportError:
        raise ModuleNotFoundError("`pyyaml` is not installed")

    bms = []
    for r in records:
-        bms += r.compact(mode=fdoptions.ctxmode)
-    yaml.safe_dump(bms, fp, **fdoptions.options)
+        bms += r.compact()
+    yaml.safe_dump(bms, fp, **options)


-def yaml_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def yaml_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
    try:
        import yaml
    except ImportError:
        raise ModuleNotFoundError("`pyyaml` is not installed")

    # TODO: Use expandmany()
    bms = yaml.safe_load(fp)
    return [BenchmarkRecord.expand(bms)]


-def json_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def json_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
    import json

    bm = []
    for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    json.dump(bm, fp, **fdoptions.options)
+        bm += r.compact()
+    json.dump(bm, fp, **options)


-def json_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def json_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
    import json

-    benchmarks: list[dict[str, Any]] = json.load(fp, **fdoptions.options)
+    benchmarks: list[dict[str, Any]] = json.load(fp, **options)
    return [BenchmarkRecord.expand(benchmarks)]


-def ndjson_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def ndjson_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
    import json

    bm = []
    for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    fp.write("\n".join([json.dumps(b) for b in bm]))
+        bm += r.compact()
+    fp.write("\n".join([json.dumps(b, **options) for b in bm]))


-def ndjson_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def ndjson_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
    import json

    benchmarks: list[dict[str, Any]]
-    benchmarks = [json.loads(line, **fdoptions.options) for line in fp]
+    benchmarks = [json.loads(line, **options) for line in fp]
    return [BenchmarkRecord.expand(benchmarks)]


-def csv_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def csv_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
    import csv

    bm = []
    for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **fdoptions.options)
+        bm += r.compact()
+    writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **options)
    writer.writeheader()

    for b in bm:
        writer.writerow(b)


-def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def csv_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
    import csv
    import json

-    reader = csv.DictReader(fp, **fdoptions.options)
+    reader = csv.DictReader(fp, **options)

    benchmarks: list[dict[str, Any]] = []
    # apparently csv.DictReader has no appropriate type hint for __next__,
@@ -122,22 +111,22 @@ def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
    return [BenchmarkRecord.expand(benchmarks)]


-def parquet_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def parquet_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
    import pyarrow as pa
    import pyarrow.parquet as pq

    bm = []
    for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
+        bm += r.compact()

    table = pa.Table.from_pylist(bm)
-    pq.write_table(table, fp, **fdoptions.options)
+    pq.write_table(table, fp, **options)


-def parquet_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def parquet_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
    import pyarrow.parquet as pq

-    table = pq.read_table(fp, **fdoptions.options)
+    table = pq.read_table(fp, **options)
    benchmarks: list[dict[str, Any]] = table.to_pylist()
    return [BenchmarkRecord.expand(benchmarks)]

@@ -292,11 +281,8 @@ def read_batched(
        else:
            fd = open(file, mode)

-        # dummy value, since the context mode is unused in read ops.
-        fdoptions = FileDriverOptions(ctxmode="omit", options=options or {})
-
        with fd as fp:
-            return de(fp, fdoptions)
+            return de(fp, options or {})

    def write(
        self,
@@ -305,7 +291,6 @@ def write(
        mode: str = "w",
        driver: str | None = None,
        compression: str | None = None,
-        ctxmode: Literal["flatten", "inline", "omit"] = "inline",
        options: dict[str, Any] | None = None,
    ) -> None:
"""Greedy version of ``FileIO.write_batched()``"""
@@ -315,7 +300,6 @@ def write(
            mode=mode,
            driver=driver,
            compression=compression,
-            ctxmode=ctxmode,
            options=options,
        )

@@ -326,7 +310,6 @@ def write_batched(
        mode: str = "w",
        driver: str | None = None,
        compression: str | None = None,
-        ctxmode: Literal["flatten", "inline", "omit"] = "inline",
        options: dict[str, Any] | None = None,
    ) -> None:
"""
@@ -348,8 +331,6 @@ def write_batched(
        compression: str | None
            Compression engine to use. If None, the compression inferred from the given
            file path's extension will be used.
-        ctxmode: Literal["flatten", "inline", "omit"]
-            How to handle the benchmark context when writing the record data.
        options: dict[str, Any] | None
            Options to pass to the respective file driver implementation.
@@ -382,9 +363,8 @@ def write_batched(
        else:
            fd = open(file, mode)

-        fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
        with fd as fp:
-            ser(records, fp, fdoptions)
+            ser(records, fp, options or {})


class FileReporter(FileIO, BenchmarkReporter):
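With `FileDriverOptions` gone, a `SerDe` entry is just a pair of plain functions taking an options dict. As a sketch of the new contract, here is a hypothetical pickle driver mirroring the ``json_save``/``json_load`` pair above (the names are illustrative only, and how such a pair would be registered is outside this diff):

import pickle
from collections.abc import Sequence
from typing import IO, Any

from nnbench.types import BenchmarkRecord


def pickle_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
    # Requires a binary file handle, i.e. mode="wb".
    bm = []
    for r in records:
        bm += r.compact()
    fp.write(pickle.dumps(bm, **options))


def pickle_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
    # Requires a binary file handle, i.e. mode="rb".
    benchmarks: list[dict[str, Any]] = pickle.loads(fp.read(), **options)
    return [BenchmarkRecord.expand(benchmarks)]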
31 changes: 8 additions & 23 deletions src/nnbench/types/benchmark.py
@@ -5,7 +5,7 @@
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from types import MappingProxyType
-from typing import Any, Literal
+from typing import Any

if sys.version_info >= (3, 11):
    from typing import Self
@@ -32,32 +32,23 @@ class BenchmarkRecord:
    context: dict[str, Any]
    benchmarks: list[dict[str, Any]]

-    def compact(
-        self,
-        mode: Literal["flatten", "inline", "omit"] = "inline",
-        sep: str = ".",
-    ) -> list[dict[str, Any]]:
+    def compact(self, sep: str = ".") -> list[dict[str, Any]]:
"""
Prepare the benchmark results, optionally inlining the context either as a
nested dictionary or in flattened form.
Parameters
----------
mode: Literal["flatten", "inline", "omit"]
How to handle the context. ``"omit"`` leaves out the context entirely, ``"inline"``
inserts it into the benchmark dictionary as a single entry named ``"context"``, and
``"flatten"`` inserts the flattened context values into the dictionary.
sep: str
The separator to use when flattening the context, i.e. when ``mode = "flatten"``.
Deprecated, unused.
Returns
-------
list[dict[str, Any]]
The updated list of benchmark records.
"""
if mode == "omit":
return self.benchmarks

# TODO: Allow keeping data as top-level struct?
# i.e. .compact(inline=False) -> { "context": {...}, "benchmarks": [...] }
result = []

for b in self.benchmarks:
@@ -86,18 +77,12 @@ def expand(cls, bms: list[dict[str, Any]]) -> Self:
"""
ctx: dict[str, Any] = {}
for b in bms:
# Safeguard if the context is not in the deser'd record,
# for example if the record came from a DB query.
if "context" in b:
ctx = b.pop("context")
elif "_contextkeys" in b:
ctxkeys = b.pop("_contextkeys")
for k in ctxkeys:
# This should never throw, save for data corruption.
ctx[k] = b.pop(k)
ctx = b.pop("context", {})
return cls(context=ctx, benchmarks=bms)

    # TODO: Add an expandmany() API for returning a sequence of records for heterogeneous
    # context data.


@dataclass(frozen=True)
class Benchmark:
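After this simplification there is a single canonical layout: `compact()` inlines the context under one `"context"` key per benchmark, and `expand()` pops it back out. A small roundtrip sketch with made-up values:

from nnbench.types import BenchmarkRecord

rec = BenchmarkRecord(
    context={"python": "3.11"},
    benchmarks=[{"name": "accuracy", "value": 0.98}],
)
flat = rec.compact()
# e.g. [{"name": "accuracy", "value": 0.98, "context": {"python": "3.11"}}]
rec2 = BenchmarkRecord.expand(flat)
assert rec2.context == {"python": "3.11"}
assert rec2.benchmarks == [{"name": "accuracy", "value": 0.98}]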
12 changes: 4 additions & 8 deletions tests/test_fileio.py
@@ -1,6 +1,4 @@
-import itertools
from pathlib import Path
-from typing import Literal

import pytest

@@ -9,12 +7,10 @@


@pytest.mark.parametrize(
-    "ext,ctxmode",
-    itertools.product(["yaml", "json", "ndjson", "csv", "parquet"], ["inline", "flatten"]),
+    "ext",
+    ["yaml", "json", "ndjson", "csv", "parquet"],
)
-def test_fileio_writes_no_compression_inline(
-    tmp_path: Path, ext: str, ctxmode: Literal["inline", "flatten"]
-) -> None:
+def test_fileio_writes_no_compression_inline(tmp_path: Path, ext: str) -> None:
"""Tests data integrity for file IO roundtrips with both context modes."""
f = FileIO()

@@ -24,7 +20,7 @@ def test_fileio_writes_no_compression_inline(
    )
    file = tmp_path / f"record.{ext}"
    writemode = "wb" if ext == "parquet" else "w"
-    f.write(rec, file, mode=writemode, ctxmode="inline")
+    f.write(rec, file, mode=writemode)
    readmode = "rb" if ext == "parquet" else "r"
    rec2 = f.read(file, mode=readmode)
    # Python stdlib csv coerces everything to string.
