diff --git a/src/nnbench/runner.py b/src/nnbench/runner.py
index 400e376..c9a78cb 100644
--- a/src/nnbench/runner.py
+++ b/src/nnbench/runner.py
@@ -1,20 +1,44 @@
 """The abstract benchmark runner interface, which can be overridden for custom benchmark workloads."""
 
 from __future__ import annotations
 
+import inspect
 import logging
 import os
 import sys
+from dataclasses import asdict
 from pathlib import Path
 from typing import Any, Sequence
 
 from nnbench.context import ContextProvider
 from nnbench.reporter import BaseReporter, reporter_registry
-from nnbench.types import Benchmark, BenchmarkResult
+from nnbench.types import Benchmark, BenchmarkResult, Params
 from nnbench.util import import_file_as_module, ismodule
 
 logger = logging.getLogger(__name__)
 
 
+def _check(params: dict[str, Any], benchmarks: list[Benchmark]) -> None:
+    param_types = {k: type(v) for k, v in params.items()}
+    benchmark_interface: dict[str, inspect.Parameter] = {}
+    for bm in benchmarks:
+        for name, param in inspect.signature(bm.fn).parameters.items():
+            param_type = param.annotation
+            if name in benchmark_interface and benchmark_interface[name].annotation != param_type:
+                orig_type = benchmark_interface[name].annotation
+                raise TypeError(
+                    f"got non-unique types {orig_type}, {param_type} for parameter {name!r}"
+                )
+            benchmark_interface[name] = param
+    for name, param in benchmark_interface.items():
+        if name not in param_types and param.default == inspect.Parameter.empty:
+            raise ValueError(f"missing value for required parameter {name!r}")
+        if not issubclass(param_types[name], param.annotation):
+            raise TypeError(
+                f"expected type {param.annotation} for parameter {name!r}, "
+                f"got {param_types[name]!r}"
+            )
+
+
 def iscontainer(s: Any) -> bool:
     return isinstance(s, (tuple, list))
 
@@ -90,11 +114,11 @@ def collect(
 
     def run(
         self,
-        path_or_module: str | os.PathLike[str] = "__main__",
-        params: dict[str, Any] | None = None,
+        path_or_module: str | os.PathLike[str],
+        params: dict[str, Any] | Params,
         tags: tuple[str, ...] = (),
         context: Sequence[ContextProvider] = (),
-    ) -> BenchmarkResult:
+    ) -> BenchmarkResult | None:
         """
         Run a previously collected benchmark workload.
 
@@ -103,7 +127,7 @@ def run(
         path_or_module: str | os.PathLike[str]
            Name or path of the module to discover benchmarks in. Can also be a directory,
            in which case benchmarks are collected from the Python files therein.
-        params: dict[str, Any] | None
+        params: dict[str, Any] | Params
            Parameters to use for the benchmark run. Names have to match positional
            and keyword argument names of the benchmark functions.
         tags: tuple[str, ...]
@@ -116,7 +140,7 @@ def run(
 
        Returns
        -------
-        BenchmarkResult
+        BenchmarkResult | None
            A JSON output representing the benchmark results. Has two top-level keys,
            "context" holding the context information, and "benchmarks", holding an
            array with the benchmark results.
@@ -127,6 +151,14 @@ def run(
        # if we still have no benchmarks after collection, warn.
        if not self.benchmarks:
            logger.warning(f"No benchmarks found in path/module {str(path_or_module)!r}.")
+            return None  # TODO: Return empty result to preserve strong typing
+
+        if isinstance(params, Params):
+            dparams = asdict(params)
+        else:
+            dparams = params
+
+        _check(dparams, self.benchmarks)
 
        dcontext: dict[str, Any] = dict()
 
@@ -143,18 +175,21 @@ def run(
        results: list[dict[str, Any]] = []
 
        for benchmark in self.benchmarks:
+            # TODO: Refactor once benchmark contains interface
+            sig = inspect.signature(benchmark.fn)
+            bmparams = {k: v for k, v in dparams.items() if k in sig.parameters}
            res: dict[str, Any] = {}
-            # TODO: Validate against interface and pass only the kwargs relevant to the benchmark
-            params |= benchmark.params
            try:
-                benchmark.setUp(**params)
-                res.update(benchmark.fn(**params))
+                benchmark.setUp(**bmparams)
+                # Todo: check params
+                res["name"] = benchmark.fn.__name__
+                res["value"] = benchmark.fn(**bmparams)
            except Exception as e:
                # TODO: This needs work
                res["error_occurred"] = True
                res["error_message"] = str(e)
            finally:
-                benchmark.tearDown(**params)
+                benchmark.tearDown(**bmparams)
            results.append(res)
 
        return BenchmarkResult(
diff --git a/tests/conftest.py b/tests/conftest.py
index 84a0021..c67ae52 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -13,3 +13,10 @@
 def testfolder() -> str:
     """A test directory for benchmark collection."""
     return str(HERE / "testproject")
+
+
+# TODO: Consider merging all test directories into one,
+# filtering benchmarks by testcase via tags.
+@pytest.fixture(scope="session")
+def typecheckfolder() -> str:
+    return str(HERE / "typechecking")
diff --git a/tests/test_argcheck.py b/tests/test_argcheck.py
new file mode 100644
index 0000000..5c919e8
--- /dev/null
+++ b/tests/test_argcheck.py
@@ -0,0 +1,24 @@
+import os
+
+import pytest
+
+from nnbench import runner
+
+
+def test_argcheck(typecheckfolder: str) -> None:
+    benchmarks = os.path.join(typecheckfolder, "benchmarks.py")
+    r = runner.AbstractBenchmarkRunner()
+    with pytest.raises(TypeError, match="expected type .*"):
+        r.run(benchmarks, params={"x": 1, "y": "1"})
+    with pytest.raises(ValueError, match="missing value for required parameter.*"):
+        r.run(benchmarks, params={"x": 1})
+
+    r.run(benchmarks, params={"x": 1, "y": 1})
+
+
+def test_error_on_duplicate_params(typecheckfolder: str) -> None:
+    benchmarks = os.path.join(typecheckfolder, "duplicate_benchmarks.py")
+
+    with pytest.raises(TypeError, match="got non-unique types.*"):
+        r = runner.AbstractBenchmarkRunner()
+        r.run(benchmarks, params={"x": 1, "y": 1})
diff --git a/tests/typechecking/benchmarks.py b/tests/typechecking/benchmarks.py
new file mode 100644
index 0000000..3676c26
--- /dev/null
+++ b/tests/typechecking/benchmarks.py
@@ -0,0 +1,16 @@
+import nnbench
+
+
+@nnbench.benchmark
+def double(x: int) -> int:
+    return x * 2
+
+
+@nnbench.benchmark
+def triple(y: int) -> int:
+    return y * 3
+
+
+@nnbench.benchmark
+def prod(x: int, y: int) -> int:
+    return x * y
diff --git a/tests/typechecking/duplicate_benchmarks.py b/tests/typechecking/duplicate_benchmarks.py
new file mode 100644
index 0000000..4819a88
--- /dev/null
+++ b/tests/typechecking/duplicate_benchmarks.py
@@ -0,0 +1,11 @@
+import nnbench
+
+
+@nnbench.benchmark
+def double(x: int) -> int:
+    return x * 2
+
+
+@nnbench.benchmark
+def triple_str(x: str) -> str:
+    return x * 3
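
For reviewers, a minimal usage sketch of the new parameter validation (not part of the diff itself; it assumes the tests/typechecking/benchmarks.py module added above). The runner now converts a Params dataclass to a dict, validates names and annotated types against the collected benchmark signatures via _check(), and passes each benchmark only the kwargs present in its own signature:

from nnbench import runner

r = runner.AbstractBenchmarkRunner()

# Type mismatch: 'y' is annotated as int in triple()/prod(), so a str value fails
# the issubclass() check with "expected type <class 'int'> for parameter 'y', ...".
# r.run("tests/typechecking/benchmarks.py", params={"x": 1, "y": "1"})

# Missing required parameter: 'y' has no default, so this raises
# "missing value for required parameter 'y'".
# r.run("tests/typechecking/benchmarks.py", params={"x": 1})

# Valid names and types: double() receives only x, triple() only y, prod() both.
result = r.run("tests/typechecking/benchmarks.py", params={"x": 1, "y": 1})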