From 84a31c3835c2d320967b78d962695cdf31a757d4 Mon Sep 17 00:00:00 2001
From: Nicholas Junge
Date: Mon, 11 Nov 2024 16:46:57 +0100
Subject: [PATCH] Remove context processing options

The context struct can now be passed through directly in all cases we
currently support, so we no longer expose a separate context-handling
option.
---
 src/nnbench/reporter/file.py   | 76 +++++++++++++---------------------
 src/nnbench/types/benchmark.py | 25 ++---------
 tests/test_fileio.py           | 12 ++----
 3 files changed, 36 insertions(+), 77 deletions(-)

diff --git a/src/nnbench/reporter/file.py b/src/nnbench/reporter/file.py
index 835289c6..c457900a 100644
--- a/src/nnbench/reporter/file.py
+++ b/src/nnbench/reporter/file.py
@@ -1,26 +1,16 @@
 import os
 import threading
 from collections.abc import Callable, Sequence
-from dataclasses import dataclass, field
 from pathlib import Path
 from typing import IO, Any, Literal, cast
 
 from nnbench.reporter.base import BenchmarkReporter
 from nnbench.types import BenchmarkRecord
 
-
-@dataclass(frozen=True)
-class FileDriverOptions:
-    options: dict[str, Any] = field(default_factory=dict)
-    """Options to pass to the underlying serialization API call, e.g. ``json.dump``."""
-    ctxmode: Literal["flatten", "inline", "omit"] = "inline"
-    """How to handle the context struct."""
-
-
 _Options = dict[str, Any]
 SerDe = tuple[
-    Callable[[Sequence[BenchmarkRecord], IO, FileDriverOptions], None],
-    Callable[[IO, FileDriverOptions], list[BenchmarkRecord]],
+    Callable[[Sequence[BenchmarkRecord], IO, dict[str, Any]], None],
+    Callable[[IO, dict[str, Any]], list[BenchmarkRecord]],
 ]
 
 
@@ -30,7 +20,7 @@ class FileDriverOptions:
 _compression_lock = threading.Lock()
 
 
-def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     try:
         import yaml
     except ImportError:
@@ -38,72 +28,71 @@ def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverO
 
     bms = []
     for r in records:
-        bms += r.compact(mode=fdoptions.ctxmode)
-    yaml.safe_dump(bms, fp, **fdoptions.options)
+        bms += r.compact()
+    yaml.safe_dump(bms, fp, **options)
 
 
-def yaml_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def yaml_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     try:
         import yaml
     except ImportError:
         raise ModuleNotFoundError("`pyyaml` is not installed")
 
-    # TODO: Use expandmany()
     bms = yaml.safe_load(fp)
     return [BenchmarkRecord.expand(bms)]
 
 
-def json_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def json_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     import json
 
     bm = []
     for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    json.dump(bm, fp, **fdoptions.options)
+        bm += r.compact()
+    json.dump(bm, fp, **options)
 
 
-def json_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def json_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     import json
 
-    benchmarks: list[dict[str, Any]] = json.load(fp, **fdoptions.options)
+    benchmarks: list[dict[str, Any]] = json.load(fp, **options)
     return [BenchmarkRecord.expand(benchmarks)]
 
 
-def ndjson_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def ndjson_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     import json
 
     bm = []
     for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    fp.write("\n".join([json.dumps(b) for b in bm]))
+        bm += r.compact()
+    fp.write("\n".join([json.dumps(b, **options) for b in bm]))
 
 
-def ndjson_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def ndjson_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     import json
 
     benchmarks: list[dict[str, Any]]
-    benchmarks = [json.loads(line, **fdoptions.options) for line in fp]
+    benchmarks = [json.loads(line, **options) for line in fp]
     return [BenchmarkRecord.expand(benchmarks)]
 
 
-def csv_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def csv_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
    import csv
 
     bm = []
     for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
-    writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **fdoptions.options)
+        bm += r.compact()
+    writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **options)
     writer.writeheader()
     for b in bm:
         writer.writerow(b)
 
 
-def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def csv_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     import csv
     import json
 
-    reader = csv.DictReader(fp, **fdoptions.options)
+    reader = csv.DictReader(fp, **options)
 
     benchmarks: list[dict[str, Any]] = []
 
     # apparently csv.DictReader has no appropriate type hint for __next__,
@@ -122,22 +111,22 @@ def csv_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
     return [BenchmarkRecord.expand(benchmarks)]
 
 
-def parquet_save(records: Sequence[BenchmarkRecord], fp: IO, fdoptions: FileDriverOptions) -> None:
+def parquet_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
     import pyarrow as pa
     import pyarrow.parquet as pq
 
     bm = []
     for r in records:
-        bm += r.compact(mode=fdoptions.ctxmode)
+        bm += r.compact()
     table = pa.Table.from_pylist(bm)
-    pq.write_table(table, fp, **fdoptions.options)
+    pq.write_table(table, fp, **options)
 
 
-def parquet_load(fp: IO, fdoptions: FileDriverOptions) -> list[BenchmarkRecord]:
+def parquet_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
     import pyarrow.parquet as pq
 
-    table = pq.read_table(fp, **fdoptions.options)
+    table = pq.read_table(fp, **options)
     benchmarks: list[dict[str, Any]] = table.to_pylist()
     return [BenchmarkRecord.expand(benchmarks)]
 
@@ -292,11 +281,8 @@ def read_batched(
         else:
             fd = open(file, mode)
 
-        # dummy value, since the context mode is unused in read ops.
-        fdoptions = FileDriverOptions(ctxmode="omit", options=options or {})
-
         with fd as fp:
-            return de(fp, fdoptions)
+            return de(fp, options or {})
 
     def write(
         self,
@@ -305,7 +291,6 @@ def write(
         mode: str = "w",
         driver: str | None = None,
         compression: str | None = None,
-        ctxmode: Literal["flatten", "inline", "omit"] = "inline",
         options: dict[str, Any] | None = None,
     ) -> None:
         """Greedy version of ``FileIO.write_batched()``"""
@@ -315,7 +300,6 @@ def write(
             mode=mode,
             driver=driver,
             compression=compression,
-            ctxmode=ctxmode,
             options=options,
         )
 
@@ -326,7 +310,6 @@ def write_batched(
         mode: str = "w",
         driver: str | None = None,
         compression: str | None = None,
-        ctxmode: Literal["flatten", "inline", "omit"] = "inline",
         options: dict[str, Any] | None = None,
     ) -> None:
         """
@@ -348,8 +331,6 @@ def write_batched(
         compression: str | None
             Compression engine to use. If None, the compression inferred from the given
             file path's extension will be used.
-        ctxmode: Literal["flatten", "inline", "omit"]
-            How to handle the benchmark context when writing the record data.
         options: dict[str, Any] | None
             Options to pass to the respective file driver implementation.
 
@@ -382,9 +363,8 @@ def write_batched(
         else:
             fd = open(file, mode)
 
-        fdoptions = FileDriverOptions(ctxmode=ctxmode, options=options or {})
         with fd as fp:
-            ser(records, fp, fdoptions)
+            ser(records, fp, options or {})
 
 
 class FileReporter(FileIO, BenchmarkReporter):
diff --git a/src/nnbench/types/benchmark.py b/src/nnbench/types/benchmark.py
index 83c7216a..22f86e69 100644
--- a/src/nnbench/types/benchmark.py
+++ b/src/nnbench/types/benchmark.py
@@ -5,7 +5,7 @@
 from collections.abc import Callable, Mapping
 from dataclasses import dataclass, field
 from types import MappingProxyType
-from typing import Any, Literal
+from typing import Any
 
 if sys.version_info >= (3, 11):
     from typing import Self
@@ -32,21 +32,13 @@ class BenchmarkRecord:
     context: dict[str, Any]
     benchmarks: list[dict[str, Any]]
 
-    def compact(
-        self,
-        mode: Literal["flatten", "inline", "omit"] = "inline",
-        sep: str = ".",
-    ) -> list[dict[str, Any]]:
+    def compact(self, sep: str = ".") -> list[dict[str, Any]]:
         """
-        Prepare the benchmark results, optionally inlining the context
-        either as a nested dictionary or in flattened form.
+        Prepare the benchmark results, inlining the context into each
+        benchmark as a single entry named ``"context"``.
 
         Parameters
         ----------
-        mode: Literal["flatten", "inline", "omit"]
-            How to handle the context. ``"omit"`` leaves out the context entirely, ``"inline"``
-            inserts it into the benchmark dictionary as a single entry named ``"context"``, and
-            ``"flatten"`` inserts the flattened context values into the dictionary.
         sep: str
-            The separator to use when flattening the context, i.e. when
-            ``mode = "flatten"``.
+            The separator to use when flattening the context into the
+            benchmark dictionaries.
 
@@ -55,9 +47,6 @@ def compact(
         list[dict[str, Any]]
             The updated list of benchmark records.
         """
-        if mode == "omit":
-            return self.benchmarks
-
         result = []
 
         for b in self.benchmarks:
@@ -86,18 +75,12 @@ def expand(cls, bms: list[dict[str, Any]]) -> Self:
         """
         ctx: dict[str, Any] = {}
         for b in bms:
+            # Safeguard if the context is not in the deser'd record,
+            # for example if the record came from a DB query.
             if "context" in b:
                 ctx = b.pop("context")
-            elif "_contextkeys" in b:
-                ctxkeys = b.pop("_contextkeys")
-                for k in ctxkeys:
-                    # This should never throw, save for data corruption.
-                    ctx[k] = b.pop(k)
         return cls(context=ctx, benchmarks=bms)
 
-    # TODO: Add an expandmany() API for returning a sequence of records for heterogeneous
-    # context data.
-
 
 @dataclass(frozen=True)
 class Benchmark:
diff --git a/tests/test_fileio.py b/tests/test_fileio.py
index 16988d5a..04f87955 100644
--- a/tests/test_fileio.py
+++ b/tests/test_fileio.py
@@ -1,6 +1,4 @@
-import itertools
 from pathlib import Path
-from typing import Literal
 
 import pytest
 
@@ -9,12 +7,10 @@
 
 
 @pytest.mark.parametrize(
-    "ext,ctxmode",
-    itertools.product(["yaml", "json", "ndjson", "csv", "parquet"], ["inline", "flatten"]),
+    "ext",
+    ["yaml", "json", "ndjson", "csv", "parquet"],
 )
-def test_fileio_writes_no_compression_inline(
-    tmp_path: Path, ext: str, ctxmode: Literal["inline", "flatten"]
-) -> None:
-    """Tests data integrity for file IO roundtrips with both context modes."""
+def test_fileio_writes_no_compression(tmp_path: Path, ext: str) -> None:
+    """Tests data integrity for file IO roundtrips."""
     f = FileIO()
 
@@ -24,7 +20,7 @@
     )
     file = tmp_path / f"record.{ext}"
     writemode = "wb" if ext == "parquet" else "w"
-    f.write(rec, file, mode=writemode, ctxmode="inline")
+    f.write(rec, file, mode=writemode)
     readmode = "rb" if ext == "parquet" else "r"
     rec2 = f.read(file, mode=readmode)
     # Python stdlib csv coerces everything to string.
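
Usage sketch of the API after this patch: records always inline their
context on write, and driver kwargs travel in a plain options dict. This
is a minimal round trip, assuming the FileIO import path from the module
layout in the diff; the record contents and file name are illustrative:

    from pathlib import Path

    from nnbench.reporter.file import FileIO
    from nnbench.types import BenchmarkRecord

    rec = BenchmarkRecord(
        context={"git": {"commit": "84a31c3"}},
        benchmarks=[{"name": "accuracy", "value": 0.98}],
    )

    f = FileIO()
    # `options` is forwarded verbatim to the driver call, here json.dump();
    # note there is no ctxmode= argument anymore.
    f.write(rec, Path("record.json"), mode="w", options={"indent": 2})
    rec2 = f.read(Path("record.json"), mode="r")

    # compact() inlines the context under a "context" key and expand()
    # pops it back out, so the context should survive the round trip.
    assert rec2.context == rec.context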