Skip to content

Commit

Permalink
Skip typecheck immediately, add thunk support
Browse files Browse the repository at this point in the history
Slightly changes parameter construction and adds a dethunking step right
before the benchmark loop.

This means that the thunk values are accessed at the latest possible
time, which is just before benchmark execution.

Moves the context construction ahead of the empty collection check, so
that we give back a constructed context even in the case of no found
benchmarks.

Adds two C++-style thunk helpers, `is_thunk` for deciding if a value is
a thunk, and `is_thunk_type` to decide if a value type is a thunk type
annotation.

The whole thunk facility is designed to work both with the
`nnbench.types.Thunk` type as well as with general anonymous functions.
  • Loading branch information
nicholasjng committed Mar 19, 2024
1 parent b88073d commit c789d72
Showing 1 changed file with 66 additions and 33 deletions.
99 changes: 66 additions & 33 deletions src/nnbench/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

from __future__ import annotations

import collections.abc
import contextlib
import inspect
import logging
Expand All @@ -12,7 +13,7 @@
from dataclasses import asdict
from datetime import datetime
from pathlib import Path
from typing import Any, Callable, Generator, Sequence, get_origin
from typing import Any, Callable, Generator, Sequence, get_args, get_origin

from nnbench.context import Context, ContextProvider
from nnbench.types import Benchmark, BenchmarkRecord, Parameters
Expand All @@ -29,12 +30,18 @@ def isdunder(s: str) -> bool:
return s.startswith("__") and s.endswith("__")


def is_thunk(v: Any) -> bool:
return callable(v) and len(inspect.signature(v).parameters) == 0


def is_thunk_type(t: type) -> bool:
return get_origin(t) is collections.abc.Callable and get_args(t)[0] == []


def qualname(fn: Callable) -> str:
fnname = fn.__name__
fnqualname = fn.__qualname__
if fnname == fnqualname:
return fnname
return f"{fnqualname}.{fnname}"
if fn.__name__ == fn.__qualname__:
return fn.__name__
return f"{fn.__qualname__}.{fn.__name__}"


@contextlib.contextmanager
Expand Down Expand Up @@ -72,8 +79,11 @@ def __init__(self, typecheck: bool = True):
self.typecheck = typecheck

def _check(self, params: dict[str, Any]) -> None:
param_types = {k: type(v) for k, v in params.items()}
if not self.typecheck:
return

allvars: dict[str, tuple[type, Any]] = {}
required: set[str] = set()
empty = inspect.Parameter.empty

def _issubtype(t1: type, t2: type) -> bool:
Expand All @@ -89,9 +99,12 @@ def _issubtype(t1: type, t2: type) -> bool:
# TODO: Extend typing checks to args.
return issubclass(t1, t2)

# stitch together the union interface comprised of all benchmarks.
for bm in self.benchmarks:
for var in bm.interface.variables:
name, typ, default = var
if default == empty:
required.add(name)
if name in params and default != empty:
logger.debug(
f"using given value {params[name]} over default value {default} "
Expand All @@ -101,7 +114,7 @@ def _issubtype(t1: type, t2: type) -> bool:
if typ == empty:
logger.debug(f"parameter {name!r} untyped in benchmark {bm.name}().")

if name in allvars and self.typecheck:
if name in allvars:
currvar = allvars[name]
orig_type, orig_val = new_type, new_val = currvar
# If a benchmark has a variable without a default value,
Expand All @@ -126,20 +139,32 @@ def _issubtype(t1: type, t2: type) -> bool:
else:
allvars[name] = (typ, default)

for name, (typ, default) in allvars.items():
# check if a no-default variable has no parameter.
if name not in param_types and default == empty:
raise ValueError(f"missing value for required parameter {name!r}")
# check if any required variable has no parameter.
missing = required - params.keys()
if missing:
msng, *_ = missing
raise ValueError(f"missing value for required parameter {msng!r}")

for k, v in params.items():
if k not in allvars:
warnings.warn(
f"ignoring parameter {k!r} since it is not part of any benchmark interface."
)
continue

# skip the subsequent type check if the variable is untyped,
# or if typechecks are disabled.
if typ == empty or not self.typecheck:
typ, default = allvars[k]
# skip the subsequent type check if the variable is untyped.
if typ == empty:
continue

vtype = type(v)
if is_thunk(v) and not is_thunk_type(typ):
# in case of a thunk, check the result type of __call__() instead.
vtype = inspect.signature(v).return_annotation

# type-check parameter value against the narrowest hinted type.
if name in param_types and not _issubtype(param_types[name], typ):
raise TypeError(
f"expected type {typ} for parameter {name!r}, got {param_types[name]}"
)
if not _issubtype(vtype, typ):
raise TypeError(f"expected type {typ} for parameter {k!r}, got {vtype}")

def clear(self) -> None:
"""Clear all registered benchmarks."""
Expand Down Expand Up @@ -230,30 +255,38 @@ def run(
if not self.benchmarks:
self.collect(path_or_module, tags)

# if we still have no benchmarks after collection, warn and return an empty record.
if isinstance(context, Context):
ctx = context
else:
ctx = Context()
for provider in context:
ctx.add(provider)

# if we didn't find any benchmarks, warn and return an empty record.
if not self.benchmarks:
warnings.warn(f"No benchmarks found in path/module {str(path_or_module)!r}.")
return BenchmarkRecord(context=Context(), benchmarks=[])
return BenchmarkRecord(context=ctx, benchmarks=[])

params = params or {}
if isinstance(params, Parameters):
dparams = asdict(params)
else:
dparams = params
dparams = params or {}

self._check(dparams)
results: list[dict[str, Any]] = []

if isinstance(context, Context):
ctx = context
else:
ctx = Context()
for provider in context:
ctx.add(provider)
def _maybe_dethunk(v, expected_type):
if is_thunk(v) and not is_thunk_type(expected_type):
return v()
return v

results: list[dict[str, Any]] = []
for benchmark in self.benchmarks:
bmparams = {k: v for k, v in dparams.items() if k in benchmark.interface.names}
bmdefaults = {k: v for (k, _, v) in benchmark.interface.variables}
bmtypes = dict(zip(benchmark.interface.names, benchmark.interface.types))
bmparams = dict(zip(benchmark.interface.names, benchmark.interface.defaults))
# TODO: Does this need a copy.deepcopy()?
bmparams |= {k: v for k, v in dparams.items() if k in bmparams}
bmparams = {k: _maybe_dethunk(v, bmtypes[k]) for k, v in bmparams.items()}

# TODO: Wrap this into an execution context
res: dict[str, Any] = {
"name": benchmark.name,
Expand All @@ -262,7 +295,7 @@ def run(
"date": datetime.now().isoformat(timespec="seconds"),
"error_occurred": False,
"error_message": "",
"parameters": bmdefaults | bmparams,
"parameters": bmparams,
}
try:
benchmark.setUp(**bmparams)
Expand Down

0 comments on commit c789d72

Please sign in to comment.