From 2c4138e1c8c4f2858ca25de8e3e9a59bd190e2e7 Mon Sep 17 00:00:00 2001 From: Nicholas Junge Date: Thu, 18 Jan 2024 14:46:48 +0100 Subject: [PATCH] Add `nnbench.reporter` module, first tabular console implementation The interface is not yet stable, since there are no options defined on the class constructor, and the console does not take any arguments yet. --- pyproject.toml | 2 +- src/nnbench/__init__.py | 4 ++- src/nnbench/reporter.py | 56 +++++++++++++++++++++++++++++++++++++++++ src/nnbench/runner.py | 39 +++++++++++++++++++++++++--- 4 files changed, 96 insertions(+), 5 deletions(-) create mode 100644 src/nnbench/reporter.py diff --git a/pyproject.toml b/pyproject.toml index dd2af88..730417c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -83,7 +83,7 @@ strict_optional = false warn_unreachable = true [[tool.mypy.overrides]] -module = ["yaml"] +module = ["tabulate", "yaml"] ignore_missing_imports = true [tool.ruff] diff --git a/src/nnbench/__init__.py b/src/nnbench/__init__.py index 077d727..00458de 100644 --- a/src/nnbench/__init__.py +++ b/src/nnbench/__init__.py @@ -9,4 +9,6 @@ pass # TODO: This naming is unfortunate -from .core import Benchmark, Params, benchmark, parametrize +from .core import benchmark, parametrize +from .reporter import BaseReporter +from .types import Benchmark, Params diff --git a/src/nnbench/reporter.py b/src/nnbench/reporter.py new file mode 100644 index 0000000..b6d8545 --- /dev/null +++ b/src/nnbench/reporter.py @@ -0,0 +1,56 @@ +import sys +import types +from typing import Any + +from nnbench.types import BenchmarkResult + + +class BaseReporter: + """ + The base interface for a benchmark reporter class. + + A benchmark reporter consumes benchmark results from a run, and subsequently + reports them in the way specified by the respective implementation's `report()` + method. + + For example, to write benchmark results to a database, you could save the credentials + for authentication in the class constructor, and then stream the results directly to + the database in `report()`, with preprocessing if necessary. + + Parameters + ---------- + **kwargs: Any + Additional keyword arguments, for compatibility with subclass interfaces. + """ + + def __init__(self, **kwargs: Any): + pass + + def report(self, result: BenchmarkResult) -> None: + raise NotImplementedError + + +class ConsoleReporter(BaseReporter): + # TODO: Implement regex filters, context values, display options, ... (__init__) + def report(self, result: BenchmarkResult) -> None: + try: + from tabulate import tabulate + except ModuleNotFoundError: + raise ValueError( + f"{self.__class__.__name__} requires `tabulate` to be installed. " + f"To install, run `{sys.executable} -m pip install --upgrade tabulate`." + ) + + benchmarks = result["benchmarks"] + print(tabulate(benchmarks, headers="keys")) + + +# internal, mutable +_reporter_registry: dict[str, type[BaseReporter]] = { + "console": ConsoleReporter, +} + +# external, immutable +reporter_registry: types.MappingProxyType[str, type[BaseReporter]] = types.MappingProxyType( + _reporter_registry +) diff --git a/src/nnbench/runner.py b/src/nnbench/runner.py index 53fffcf..400e376 100644 --- a/src/nnbench/runner.py +++ b/src/nnbench/runner.py @@ -8,6 +8,7 @@ from typing import Any, Sequence from nnbench.context import ContextProvider +from nnbench.reporter import BaseReporter, reporter_registry from nnbench.types import Benchmark, BenchmarkResult from nnbench.util import import_file_as_module, ismodule @@ -161,6 +162,38 @@ def run( benchmarks=results, ) - def report(self) -> None: - """Report collected results from a previous run.""" - raise NotImplementedError + def report( + self, to: str | BaseReporter | Sequence[str | BaseReporter], result: BenchmarkResult + ) -> None: + """ + Report collected results from a previous run. + + Parameters + ---------- + to: str | BaseReporter | Sequence[str | BaseReporter] + The reporter to use when reporting / streaming results. Can be either a string + (which prompts a lookup of all nnbench native reporters), a reporter instance, + or a sequence thereof, which enables streaming result data to multiple sinks. + result: BenchmarkResult + The benchmark result to report. + """ + + def load_reporter(r: str | BaseReporter) -> BaseReporter: + if isinstance(r, str): + try: + return reporter_registry[r]() + except KeyError: + # TODO: Add a note on nnbench reporter entry point once supported + raise KeyError(f"unknown reporter class {r!r}") + else: + return r + + dests: tuple[BaseReporter, ...] = () + + if isinstance(to, (str, BaseReporter)): + dests += (load_reporter(to),) + else: + dests += tuple(load_reporter(t) for t in to) + + for reporter in dests: + reporter.report(result)