Add Interface-checks #12

Merged (6 commits, Jan 23, 2024)
57 changes: 46 additions & 11 deletions src/nnbench/runner.py
@@ -1,20 +1,44 @@
"""The abstract benchmark runner interface, which can be overridden for custom benchmark workloads."""
from __future__ import annotations

import inspect
import logging
import os
import sys
from dataclasses import asdict
from pathlib import Path
from typing import Any, Sequence

from nnbench.context import ContextProvider
from nnbench.reporter import BaseReporter, reporter_registry
-from nnbench.types import Benchmark, BenchmarkResult
+from nnbench.types import Benchmark, BenchmarkResult, Params
from nnbench.util import import_file_as_module, ismodule

logger = logging.getLogger(__name__)


def _check(params: dict[str, Any], benchmarks: list[Benchmark]) -> None:
    param_types = {k: type(v) for k, v in params.items()}
    benchmark_interface: dict[str, inspect.Parameter] = {}
    for bm in benchmarks:
        for name, param in inspect.signature(bm.fn).parameters.items():
            param_type = param.annotation
            if name in benchmark_interface and benchmark_interface[name].annotation != param_type:
                orig_type = benchmark_interface[name].annotation
                raise TypeError(
                    f"got non-unique types {orig_type}, {param_type} for parameter {name!r}"
                )
            benchmark_interface[name] = param
    for name, param in benchmark_interface.items():
        if name not in param_types and param.default == inspect.Parameter.empty:
            raise ValueError(f"missing value for required parameter {name!r}")
        if not issubclass(param_types[name], param.annotation):
            raise TypeError(
                f"expected type {param.annotation} for parameter {name!r}, "
                f"got {param_types[name]!r}"
            )


def iscontainer(s: Any) -> bool:
    return isinstance(s, (tuple, list))

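Editorial aside, not part of the diff: the `_check` helper above unions the parameter annotations of all collected benchmarks into one interface and then validates the supplied params against it. A minimal standalone sketch of the same idea, using plain functions and illustrative names (`merged_interface` is not from the PR):

import inspect
from typing import Any

def double(x: int) -> int:
    return x * 2

def prod(x: int, y: int) -> int:
    return x * y

# Union both signatures into one interface, as _check does for benchmarks.
merged_interface: dict[str, inspect.Parameter] = {}
for fn in (double, prod):
    for name, param in inspect.signature(fn).parameters.items():
        if name in merged_interface and merged_interface[name].annotation != param.annotation:
            raise TypeError(f"got non-unique types for parameter {name!r}")
        merged_interface[name] = param

# Validate a params dict against the merged interface.
params: dict[str, Any] = {"x": 1, "y": 2}
for name, param in merged_interface.items():
    if name not in params:
        if param.default == inspect.Parameter.empty:
            raise ValueError(f"missing value for required parameter {name!r}")
        continue
    if not issubclass(type(params[name]), param.annotation):
        raise TypeError(f"expected type {param.annotation} for parameter {name!r}")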
@@ -90,11 +114,11 @@ def collect(

    def run(
        self,
-        path_or_module: str | os.PathLike[str] = "__main__",
-        params: dict[str, Any] | None = None,
+        path_or_module: str | os.PathLike[str],
+        params: dict[str, Any] | Params,
        tags: tuple[str, ...] = (),
        context: Sequence[ContextProvider] = (),
-    ) -> BenchmarkResult:
+    ) -> BenchmarkResult | None:
        """
        Run a previously collected benchmark workload.
@@ -103,7 +127,7 @@ def run(
        path_or_module: str | os.PathLike[str]
            Name or path of the module to discover benchmarks in. Can also be a directory,
            in which case benchmarks are collected from the Python files therein.
-        params: dict[str, Any] | None
+        params: dict[str, Any] | Params
            Parameters to use for the benchmark run. Names have to match positional and keyword
            argument names of the benchmark functions.
        tags: tuple[str, ...]
@@ -116,7 +140,7 @@ def run(
        Returns
        -------
-        BenchmarkResult
+        BenchmarkResult | None
            A JSON output representing the benchmark results. Has two top-level keys, "context"
            holding the context information, and "benchmarks", holding an array with the
            benchmark results.
@@ -127,6 +151,14 @@ def run(
        # if we still have no benchmarks after collection, warn.
        if not self.benchmarks:
            logger.warning(f"No benchmarks found in path/module {str(path_or_module)!r}.")
            return None  # TODO: Return empty result to preserve strong typing

        if isinstance(params, Params):
            dparams = asdict(params)
        else:
            dparams = params

        _check(dparams, self.benchmarks)

        dcontext: dict[str, Any] = dict()

@@ -143,18 +175,21 @@

        results: list[dict[str, Any]] = []
        for benchmark in self.benchmarks:
+            # TODO: Refactor once benchmark contains interface
+            sig = inspect.signature(benchmark.fn)
+            bmparams = {k: v for k, v in dparams.items() if k in sig.parameters}
            res: dict[str, Any] = {}
-            # TODO: Validate against interface and pass only the kwargs relevant to the benchmark
-            params |= benchmark.params
            try:
-                benchmark.setUp(**params)
-                res.update(benchmark.fn(**params))
+                benchmark.setUp(**bmparams)
+                # Todo: check params
+                res["name"] = benchmark.fn.__name__
+                res["value"] = benchmark.fn(**bmparams)
            except Exception as e:
                # TODO: This needs work
                res["error_occurred"] = True
                res["error_message"] = str(e)
            finally:
-                benchmark.tearDown(**params)
+                benchmark.tearDown(**bmparams)
            results.append(res)

        return BenchmarkResult(
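Editorial aside, not part of the diff: with these changes, a type mismatch or a missing required parameter now fails fast, before any benchmark executes. A usage sketch mirroring the tests added below (the path is illustrative):

from nnbench import runner

r = runner.AbstractBenchmarkRunner()

# OK: params satisfy the merged interface of the collected benchmarks.
r.run("tests/typechecking/benchmarks.py", params={"x": 1, "y": 1})

# Would raise TypeError("expected type <class 'int'> for parameter 'y', got <class 'str'>"):
# r.run("tests/typechecking/benchmarks.py", params={"x": 1, "y": "1"})

# Would raise ValueError("missing value for required parameter 'y'"):
# r.run("tests/typechecking/benchmarks.py", params={"x": 1})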
7 changes: 7 additions & 0 deletions tests/conftest.py
@@ -13,3 +13,10 @@
def testfolder() -> str:
"""A test directory for benchmark collection."""
return str(HERE / "testproject")


# TODO: Consider merging all test directories into one,
# filtering benchmarks by testcase via tags.
@pytest.fixture(scope="session")
def typecheckfolder() -> str:
    return str(HERE / "typechecking")
24 changes: 24 additions & 0 deletions tests/test_argcheck.py
@@ -0,0 +1,24 @@
import os

import pytest

from nnbench import runner


def test_argcheck(typecheckfolder: str) -> None:
    benchmarks = os.path.join(typecheckfolder, "benchmarks.py")
    r = runner.AbstractBenchmarkRunner()
    with pytest.raises(TypeError, match="expected type <class 'int'>.*"):
        r.run(benchmarks, params={"x": 1, "y": "1"})
    with pytest.raises(ValueError, match="missing value for required parameter.*"):
        r.run(benchmarks, params={"x": 1})

    r.run(benchmarks, params={"x": 1, "y": 1})


def test_error_on_duplicate_params(typecheckfolder: str) -> None:
    benchmarks = os.path.join(typecheckfolder, "duplicate_benchmarks.py")

    with pytest.raises(TypeError, match="got non-unique types.*"):
        r = runner.AbstractBenchmarkRunner()
        r.run(benchmarks, params={"x": 1, "y": 1})
16 changes: 16 additions & 0 deletions tests/typechecking/benchmarks.py
@@ -0,0 +1,16 @@
import nnbench


@nnbench.benchmark
def double(x: int) -> int:
    return x * 2


@nnbench.benchmark
def triple(y: int) -> int:
    return y * 3


@nnbench.benchmark
def prod(x: int, y: int) -> int:
    return x * y
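Editorial aside: these three benchmarks jointly require x and y as ints, and the runner now filters params per benchmark signature, so double() only receives x. Assuming the result fields named in the runner changes above ("name"/"value"), a run over this file would roughly produce:

from nnbench import runner

r = runner.AbstractBenchmarkRunner()
result = r.run("tests/typechecking/benchmarks.py", params={"x": 2, "y": 3})
# The "benchmarks" array in the result would then hold entries roughly like:
#   {"name": "double", "value": 4}
#   {"name": "triple", "value": 9}
#   {"name": "prod",   "value": 6}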
11 changes: 11 additions & 0 deletions tests/typechecking/duplicate_benchmarks.py
@@ -0,0 +1,11 @@
import nnbench


@nnbench.benchmark
def double(x: int) -> int:
    return x * 2


@nnbench.benchmark
def triple_str(x: str) -> str:
    return x * 3