Skip to content

Commit

Permalink
Add Benchmark.to_list() and to_json() to improve record IO (#162)
Browse files Browse the repository at this point in the history
* Delete stale transforms doc

* Change benchmark record export capabilities, update file IO

It is useful to be able to export a record to both a struct (JSON, YAML)
and to a list of results (tabular formats like CSV, Parquet, databases).

The previous way of choosing either mode via a string literal was not
great, so now, we restructure the API to `BenchmarkRecord.to_{json,list}`
for either export, and make BenchmarkRecord.expand() capable of dealing
with either.

* Run pre-commit autoupdate

* Remove transforms doc from mkdocs.yml
  • Loading branch information
nicholasjng authored Nov 13, 2024
1 parent eba9340 commit f5792fa
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 189 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@ repos:
types_or: [ python, pyi ]
args: [--ignore-missing-imports, --explicit-package-bases]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.7.1
rev: v0.7.3
hooks:
- id: ruff
args: [--fix, --exit-non-zero-on-fix]
- id: ruff-format
- repo: https://github.com/astral-sh/uv-pre-commit
rev: 0.4.27
rev: 0.5.1
hooks:
- id: uv-lock
name: Lock project dependencies
61 changes: 0 additions & 61 deletions docs/guides/transforms.md

This file was deleted.

4 changes: 1 addition & 3 deletions examples/bq/bq.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,7 @@ def main():
runner = nnbench.BenchmarkRunner()
res = runner.run("benchmarks.py", params={"a": 1, "b": 1}, context=(GitEnvironmentInfo(),))

load_job = client.load_table_from_json(
res.compact(mode="flatten", sep="_"), table_id, job_config=job_config
)
load_job = client.load_table_from_json(res.to_json(), table_id, job_config=job_config)
load_job.result()


Expand Down
2 changes: 1 addition & 1 deletion examples/mnist/mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ class MNISTTestParameters(nnbench.Parameters):


class ConvNet(nn.Module):
@nn.compact
@nn.to_json
def __call__(self, x):
x = nn.Conv(features=32, kernel_size=(3, 3))(x)
x = nn.relu(x)
Expand Down
2 changes: 1 addition & 1 deletion examples/prefect/src/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ async def write(
) -> None:
await create_table_artifact(
key=key,
table=record.compact(mode="flatten"),
table=record.to_json(),
description=description,
)

Expand Down
1 change: 0 additions & 1 deletion mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ nav:
- guides/organization.md
- guides/runners.md
- guides/memoization.md
- guides/transforms.md
- Examples:
- tutorials/index.md
- tutorials/huggingface.md
Expand Down
2 changes: 1 addition & 1 deletion src/nnbench/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,8 @@ def git_subprocess(args: list[str]) -> subprocess.CompletedProcess:
p = git_subprocess(["remote", "get-url", self.remote])
if not p.returncode:
remotename: str = p.stdout.strip()
# it's an SSH remote.
if "@" in remotename:
# it's an SSH remote.
prefix, sep = "git@", ":"
else:
# it is HTTPS.
Expand Down
3 changes: 2 additions & 1 deletion src/nnbench/reporter/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@

from nnbench.reporter.util import nullcols
from nnbench.types import BenchmarkRecord
from nnbench.util import flatten


# TODO: Add IO mixins for database, file, and HTTP IO
Expand Down Expand Up @@ -107,7 +108,7 @@ def display(
continue
filteredctx = {
k: v
for k, v in record.context.items()
for k, v in flatten(record.context).items()
if any(k.startswith(i) for i in include_context)
}
filteredbm = {k: v for k, v in bm.items() if k in cols}
Expand Down
125 changes: 38 additions & 87 deletions src/nnbench/reporter/file.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import threading
from collections.abc import Callable, Sequence
from collections.abc import Callable
from pathlib import Path
from typing import IO, Any, Literal, cast

Expand All @@ -9,8 +9,8 @@

_Options = dict[str, Any]
SerDe = tuple[
Callable[[Sequence[BenchmarkRecord], IO, dict[str, Any]], None],
Callable[[IO, dict[str, Any]], list[BenchmarkRecord]],
Callable[[BenchmarkRecord, IO, dict[str, Any]], None],
Callable[[IO, dict[str, Any]], BenchmarkRecord],
]


Expand All @@ -20,75 +20,67 @@
_compression_lock = threading.Lock()


def yaml_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
def yaml_save(record: BenchmarkRecord, fp: IO, options: dict[str, Any]) -> None:
try:
import yaml
except ImportError:
raise ModuleNotFoundError("`pyyaml` is not installed")

bms = []
for r in records:
bms += r.compact()
yaml.safe_dump(bms, fp, **options)
yaml.safe_dump(record.to_json(), fp, **options)


def yaml_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
def yaml_load(fp: IO, options: dict[str, Any]) -> BenchmarkRecord:
try:
import yaml
except ImportError:
raise ModuleNotFoundError("`pyyaml` is not installed")

bms = yaml.safe_load(fp)
return [BenchmarkRecord.expand(bms)]
return BenchmarkRecord.expand(bms)


def json_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
def json_save(record: BenchmarkRecord, fp: IO, options: dict[str, Any]) -> None:
import json

bm = []
for r in records:
bm += r.compact()
json.dump(bm, fp, **options)
json.dump(record.to_json(), fp, **options)


def json_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
def json_load(fp: IO, options: dict[str, Any]) -> BenchmarkRecord:
import json

benchmarks: list[dict[str, Any]] = json.load(fp, **options)
return [BenchmarkRecord.expand(benchmarks)]
benchmarks = json.load(fp, **options)
return BenchmarkRecord.expand(benchmarks)


def ndjson_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
def ndjson_save(record: BenchmarkRecord, fp: IO, options: dict[str, Any]) -> None:
# mode is unused, since NDJSON requires every individual benchmark to be one line.
import json

bm = []
for r in records:
bm += r.compact()
fp.write("\n".join([json.dumps(b, **options) for b in bm]))
bms = record.to_list()
fp.write("\n".join([json.dumps(b, **options) for b in bms]))


def ndjson_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
def ndjson_load(fp: IO, options: dict[str, Any]) -> BenchmarkRecord:
import json

benchmarks: list[dict[str, Any]]
benchmarks = [json.loads(line, **options) for line in fp]
return [BenchmarkRecord.expand(benchmarks)]
return BenchmarkRecord.expand(benchmarks)


def csv_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
def csv_save(record: BenchmarkRecord, fp: IO, options: dict[str, Any]) -> None:
# mode is unused, since NDJSON requires every individual benchmark to be one line.
import csv

bm = []
for r in records:
bm += r.compact()
bm = record.to_list()
writer = csv.DictWriter(fp, fieldnames=bm[0].keys(), **options)
writer.writeheader()

for b in bm:
writer.writerow(b)


def csv_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
def csv_load(fp: IO, options: dict[str, Any]) -> BenchmarkRecord:
import csv
import json

Expand All @@ -102,33 +94,28 @@ def csv_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
benchmarks.append(bm)
# it can happen that the context is inlined as a stringified JSON object
# (e.g. in CSV), so we optionally JSON-load the context.
for key in ("context", "_contextkeys"):
if key in bm:
strctx: str = bm[key]
# TODO: This does not play nicely with doublequote, maybe re.sub?
strctx = strctx.replace("'", '"')
bm[key] = json.loads(strctx)
return [BenchmarkRecord.expand(benchmarks)]
if "context" in bm:
strctx: str = bm["context"]
# TODO: This does not play nicely with doublequote, maybe re.sub?
strctx = strctx.replace("'", '"')
bm["context"] = json.loads(strctx)
return BenchmarkRecord.expand(benchmarks)


def parquet_save(records: Sequence[BenchmarkRecord], fp: IO, options: dict[str, Any]) -> None:
def parquet_save(record: BenchmarkRecord, fp: IO, options: dict[str, Any]) -> None:
import pyarrow as pa
import pyarrow.parquet as pq

bm = []
for r in records:
bm += r.compact()

table = pa.Table.from_pylist(bm)
table = pa.Table.from_pylist(record.to_list())
pq.write_table(table, fp, **options)


def parquet_load(fp: IO, options: dict[str, Any]) -> list[BenchmarkRecord]:
def parquet_load(fp: IO, options: dict[str, Any]) -> BenchmarkRecord:
import pyarrow.parquet as pq

table = pq.read_table(fp, **options)
benchmarks: list[dict[str, Any]] = table.to_pylist()
return [BenchmarkRecord.expand(benchmarks)]
return BenchmarkRecord.expand(benchmarks)


def get_driver_implementation(name: str) -> SerDe:
Expand Down Expand Up @@ -210,26 +197,9 @@ def read(
options: dict[str, Any] | None = None,
) -> BenchmarkRecord:
"""
Greedy version of ``FileIO.read_batched()``, returning the first read record.
When reading a multi-record file, this uses as much memory as the batched version.
"""
records = self.read_batched(
file=file, mode=mode, driver=driver, compression=compression, options=options
)
return records[0]
Reads a benchmark record from the given file path.
def read_batched(
self,
file: str | os.PathLike[str],
mode: str = "r",
driver: str | None = None,
compression: str | None = None,
options: dict[str, Any] | None = None,
) -> list[BenchmarkRecord]:
"""
Reads a set of benchmark records from the given file path.
The file driver is chosen based on the extension found on the ``file`` path.
The file driver is chosen based on the extension in the ``file`` pathname.
Parameters
----------
Expand All @@ -248,8 +218,8 @@ def read_batched(
Returns
-------
list[BenchmarkRecord]
The benchmark records contained in the file.
BenchmarkRecord
The benchmark record contained in the file.
Raises
------
Expand Down Expand Up @@ -292,25 +262,6 @@ def write(
driver: str | None = None,
compression: str | None = None,
options: dict[str, Any] | None = None,
) -> None:
"""Greedy version of ``FileIO.write_batched()``"""
self.write_batched(
[record],
file=file,
mode=mode,
driver=driver,
compression=compression,
options=options,
)

def write_batched(
self,
records: Sequence[BenchmarkRecord],
file: str | os.PathLike[str],
mode: str = "w",
driver: str | None = None,
compression: str | None = None,
options: dict[str, Any] | None = None,
) -> None:
"""
Writes a benchmark record to the given file path.
Expand All @@ -319,7 +270,7 @@ def write_batched(
Parameters
----------
records: Sequence[BenchmarkRecord]
record: BenchmarkRecord
The record to write to the database.
file: str | os.PathLike[str]
The file name to write to.
Expand Down Expand Up @@ -364,7 +315,7 @@ def write_batched(
fd = open(file, mode)

with fd as fp:
ser(records, fp, options or {})
ser(record, fp, options or {})


class FileReporter(FileIO, BenchmarkReporter):
Expand Down
Loading

0 comments on commit f5792fa

Please sign in to comment.