Add Interface-checks (#12)
* Implement interface check

* Implement signature check

* Add tests for parameter checks

---------

Co-authored-by: Nicholas Junge <[email protected]>
maxmynter and nicholasjng authored Jan 23, 2024
1 parent e2d4e6b commit f17beab
Showing 5 changed files with 104 additions and 11 deletions.
57 changes: 46 additions & 11 deletions src/nnbench/runner.py
@@ -1,20 +1,44 @@
"""The abstract benchmark runner interface, which can be overridden for custom benchmark workloads."""
from __future__ import annotations

import inspect
import logging
import os
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Sequence

from nnbench.context import ContextProvider
from nnbench.reporter import BaseReporter, reporter_registry
-from nnbench.types import Benchmark, BenchmarkResult
+from nnbench.types import Benchmark, BenchmarkResult, Params
from nnbench.util import import_file_as_module, ismodule

logger = logging.getLogger(__name__)


def _check(params: dict[str, Any], benchmarks: list[Benchmark]) -> None:
    param_types = {k: type(v) for k, v in params.items()}
    benchmark_interface: dict[str, inspect.Parameter] = {}
    for bm in benchmarks:
        for name, param in inspect.signature(bm.fn).parameters.items():
            param_type = param.annotation
            if name in benchmark_interface and benchmark_interface[name].annotation != param_type:
                orig_type = benchmark_interface[name].annotation
                raise TypeError(
                    f"got non-unique types {orig_type}, {param_type} for parameter {name!r}"
                )
            benchmark_interface[name] = param
    for name, param in benchmark_interface.items():
        if name not in param_types and param.default == inspect.Parameter.empty:
            raise ValueError(f"missing value for required parameter {name!r}")
        if not issubclass(param_types[name], param.annotation):
            raise TypeError(
                f"expected type {param.annotation} for parameter {name!r}, "
                f"got {param_types[name]!r}"
            )


def iscontainer(s: Any) -> bool:
    return isinstance(s, (tuple, list))

@@ -90,11 +114,11 @@ def collect(

    def run(
        self,
-        path_or_module: str | os.PathLike[str] = "__main__",
-        params: dict[str, Any] | None = None,
+        path_or_module: str | os.PathLike[str],
+        params: dict[str, Any] | Params,
        tags: tuple[str, ...] = (),
        context: Sequence[ContextProvider] = (),
-    ) -> BenchmarkResult:
+    ) -> BenchmarkResult | None:
"""
Run a previously collected benchmark workload.
@@ -103,7 +127,7 @@ def run(
        path_or_module: str | os.PathLike[str]
            Name or path of the module to discover benchmarks in. Can also be a directory,
            in which case benchmarks are collected from the Python files therein.
-        params: dict[str, Any] | None
+        params: dict[str, Any] | Params
            Parameters to use for the benchmark run. Names have to match positional and keyword
            argument names of the benchmark functions.
        tags: tuple[str, ...]
@@ -116,7 +140,7 @@ def run(
        Returns
        -------
-        BenchmarkResult
+        BenchmarkResult | None
            A JSON output representing the benchmark results. Has two top-level keys, "context"
            holding the context information, and "benchmarks", holding an array with the
            benchmark results.
@@ -127,6 +151,14 @@ def run(
        # if we still have no benchmarks after collection, warn.
        if not self.benchmarks:
            logger.warning(f"No benchmarks found in path/module {str(path_or_module)!r}.")
+            return None  # TODO: Return empty result to preserve strong typing
+
+        if isinstance(params, Params):
+            dparams = asdict(params)
+        else:
+            dparams = params
+
+        _check(dparams, self.benchmarks)

        dcontext: dict[str, Any] = dict()

@@ -143,18 +175,21 @@ def run(

        results: list[dict[str, Any]] = []
        for benchmark in self.benchmarks:
+            # TODO: Refactor once benchmark contains interface
+            sig = inspect.signature(benchmark.fn)
+            bmparams = {k: v for k, v in dparams.items() if k in sig.parameters}
            res: dict[str, Any] = {}
-            # TODO: Validate against interface and pass only the kwargs relevant to the benchmark
-            params |= benchmark.params
            try:
-                benchmark.setUp(**params)
-                res.update(benchmark.fn(**params))
+                benchmark.setUp(**bmparams)
+                # Todo: check params
+                res["name"] = benchmark.fn.__name__
+                res["value"] = benchmark.fn(**bmparams)
            except Exception as e:
                # TODO: This needs work
                res["error_occurred"] = True
                res["error_message"] = str(e)
            finally:
-                benchmark.tearDown(**params)
+                benchmark.tearDown(**bmparams)
            results.append(res)

        return BenchmarkResult(
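With this change, run() accepts either a plain dict or a Params instance, normalizing the latter via dataclasses.asdict before validation. A minimal usage sketch, assuming nnbench.types.Params is a frozen dataclass base meant for subclassing; MyParams and the target file name are illustrative, not part of this commit:

from dataclasses import dataclass

from nnbench import runner
from nnbench.types import Params


@dataclass(frozen=True)
class MyParams(Params):
    x: int
    y: int


r = runner.AbstractBenchmarkRunner()
# Same effect as params={"x": 2, "y": 3}: run() converts the Params
# instance with dataclasses.asdict, then _check validates the types
# against the collected benchmark signatures before anything executes.
result = r.run("benchmarks.py", params=MyParams(x=2, y=3))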
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -13,3 +13,10 @@
def testfolder() -> str:
    """A test directory for benchmark collection."""
    return str(HERE / "testproject")
+
+
+# TODO: Consider merging all test directories into one,
+# filtering benchmarks by testcase via tags.
+@pytest.fixture(scope="session")
+def typecheckfolder() -> str:
+    return str(HERE / "typechecking")
24 changes: 24 additions & 0 deletions tests/test_argcheck.py
@@ -0,0 +1,24 @@
import os

import pytest

from nnbench import runner


def test_argcheck(typecheckfolder: str) -> None:
    benchmarks = os.path.join(typecheckfolder, "benchmarks.py")
    r = runner.AbstractBenchmarkRunner()
    with pytest.raises(TypeError, match="expected type <class 'int'>.*"):
        r.run(benchmarks, params={"x": 1, "y": "1"})
    with pytest.raises(ValueError, match="missing value for required parameter.*"):
        r.run(benchmarks, params={"x": 1})

    r.run(benchmarks, params={"x": 1, "y": 1})


def test_error_on_duplicate_params(typecheckfolder: str) -> None:
    benchmarks = os.path.join(typecheckfolder, "duplicate_benchmarks.py")

    with pytest.raises(TypeError, match="got non-unique types.*"):
        r = runner.AbstractBenchmarkRunner()
        r.run(benchmarks, params={"x": 1, "y": 1})
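The final r.run(...) call succeeds because _check compares the runtime type of each value against the annotation with issubclass rather than exact equality. A self-contained sketch of that comparison using only the standard library (nothing nnbench-specific is assumed):

import inspect


def double(x: int) -> int:
    return x * 2


param = inspect.signature(double).parameters["x"]
print(issubclass(type(1), param.annotation))     # True: int matches int
print(issubclass(type(True), param.annotation))  # True: bool is a subclass of int
print(issubclass(type("1"), param.annotation))   # False: str, as in the TypeError test above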
16 changes: 16 additions & 0 deletions tests/typechecking/benchmarks.py
@@ -0,0 +1,16 @@
import nnbench


@nnbench.benchmark
def double(x: int) -> int:
    return x * 2


@nnbench.benchmark
def triple(y: int) -> int:
    return y * 3


@nnbench.benchmark
def prod(x: int, y: int) -> int:
    return x * y
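Together these three functions exercise the parameter filtering added to runner.py above: each benchmark receives only the subset of params that its signature declares. A standalone sketch of that dispatch outside the runner (function names mirror this file; no nnbench import needed):

import inspect

dparams = {"x": 2, "y": 3}


def double(x: int) -> int:
    return x * 2


def triple(y: int) -> int:
    return y * 3


def prod(x: int, y: int) -> int:
    return x * y


for fn in (double, triple, prod):
    # Keep only the parameters this benchmark's signature declares.
    sig = inspect.signature(fn)
    bmparams = {k: v for k, v in dparams.items() if k in sig.parameters}
    print(fn.__name__, fn(**bmparams))  # double 4, triple 9, prod 6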
11 changes: 11 additions & 0 deletions tests/typechecking/duplicate_benchmarks.py
@@ -0,0 +1,11 @@
import nnbench


@nnbench.benchmark
def double(x: int) -> int:
    return x * 2


@nnbench.benchmark
def triple_str(x: str) -> str:
    return x * 3
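Because double and triple_str both declare a parameter named x with different annotations, the combined interface is ambiguous and the check refuses it. A sketch of the failure mode, assuming the private helper nnbench.runner._check keeps the shape shown in this diff (calling a private helper directly is for illustration only):

import nnbench
from nnbench.runner import _check


@nnbench.benchmark
def double(x: int) -> int:
    return x * 2


@nnbench.benchmark
def triple_str(x: str) -> str:
    return x * 3


# Raises TypeError: got non-unique types <class 'int'>, <class 'str'> for parameter 'x'
_check({"x": 1}, [double, triple_str])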
