Skip to content

Commit

Permalink
Add nnbench.reporter module, first tabular console implementation
Browse files Browse the repository at this point in the history
The interface is not yet stable, since there are no options defined on
the class constructor, and the console does not take any arguments yet.
  • Loading branch information
nicholasjng committed Jan 18, 2024
1 parent e37acb0 commit 2c4138e
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 5 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ strict_optional = false
warn_unreachable = true

[[tool.mypy.overrides]]
module = ["yaml"]
module = ["tabulate", "yaml"]
ignore_missing_imports = true

[tool.ruff]
Expand Down
4 changes: 3 additions & 1 deletion src/nnbench/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,6 @@
pass

# TODO: This naming is unfortunate
from .core import Benchmark, Params, benchmark, parametrize
from .core import benchmark, parametrize
from .reporter import BaseReporter
from .types import Benchmark, Params
56 changes: 56 additions & 0 deletions src/nnbench/reporter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import sys
import types
from typing import Any

from nnbench.types import BenchmarkResult


class BaseReporter:
"""
The base interface for a benchmark reporter class.
A benchmark reporter consumes benchmark results from a run, and subsequently
reports them in the way specified by the respective implementation's `report()`
method.
For example, to write benchmark results to a database, you could save the credentials
for authentication in the class constructor, and then stream the results directly to
the database in `report()`, with preprocessing if necessary.
Parameters
----------
**kwargs: Any
Additional keyword arguments, for compatibility with subclass interfaces.
"""

def __init__(self, **kwargs: Any):
pass

def report(self, result: BenchmarkResult) -> None:
raise NotImplementedError


class ConsoleReporter(BaseReporter):
# TODO: Implement regex filters, context values, display options, ... (__init__)
def report(self, result: BenchmarkResult) -> None:
try:
from tabulate import tabulate
except ModuleNotFoundError:
raise ValueError(
f"{self.__class__.__name__} requires `tabulate` to be installed. "
f"To install, run `{sys.executable} -m pip install --upgrade tabulate`."
)

benchmarks = result["benchmarks"]
print(tabulate(benchmarks, headers="keys"))


# internal, mutable
_reporter_registry: dict[str, type[BaseReporter]] = {
"console": ConsoleReporter,
}

# external, immutable
reporter_registry: types.MappingProxyType[str, type[BaseReporter]] = types.MappingProxyType(
_reporter_registry
)
39 changes: 36 additions & 3 deletions src/nnbench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any, Sequence

from nnbench.context import ContextProvider
from nnbench.reporter import BaseReporter, reporter_registry
from nnbench.types import Benchmark, BenchmarkResult
from nnbench.util import import_file_as_module, ismodule

Expand Down Expand Up @@ -161,6 +162,38 @@ def run(
benchmarks=results,
)

def report(self) -> None:
"""Report collected results from a previous run."""
raise NotImplementedError
def report(
self, to: str | BaseReporter | Sequence[str | BaseReporter], result: BenchmarkResult
) -> None:
"""
Report collected results from a previous run.
Parameters
----------
to: str | BaseReporter | Sequence[str | BaseReporter]
The reporter to use when reporting / streaming results. Can be either a string
(which prompts a lookup of all nnbench native reporters), a reporter instance,
or a sequence thereof, which enables streaming result data to multiple sinks.
result: BenchmarkResult
The benchmark result to report.
"""

def load_reporter(r: str | BaseReporter) -> BaseReporter:
if isinstance(r, str):
try:
return reporter_registry[r]()
except KeyError:
# TODO: Add a note on nnbench reporter entry point once supported
raise KeyError(f"unknown reporter class {r!r}")
else:
return r

dests: tuple[BaseReporter, ...] = ()

if isinstance(to, (str, BaseReporter)):
dests += (load_reporter(to),)
else:
dests += tuple(load_reporter(t) for t in to)

for reporter in dests:
reporter.report(result)

0 comments on commit 2c4138e

Please sign in to comment.