Refactor interfaces to new nnbench.types module
This module will be advertised as the one containing (most of) the
abstractions that the user can/must implement to customize nnbench.
nicholasjng committed Jan 18, 2024
1 parent 355f1e2 commit e37acb0
Showing 3 changed files with 68 additions and 62 deletions.
56 changes: 2 additions & 54 deletions src/nnbench/core.py
@@ -2,67 +2,15 @@

from __future__ import annotations

from dataclasses import dataclass, field
from typing import Any, Callable, Iterable

from nnbench.types import Benchmark

def NoOp(**kwargs: Any) -> None:
pass


# TODO: Should this be frozen (since the setUp and tearDown hooks are empty returns)?
@dataclass(init=False)
class Params:
"""
A dataclass designed to hold benchmark parameters. This class is not functional
on its own, and needs to be subclassed according to your benchmarking workloads.
The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
"""



@dataclass(frozen=True)
class Benchmark:
"""
Data model representing a benchmark. Subclass this to define your own custom benchmark.
Parameters
----------
fn: Callable[..., Any]
The function defining the benchmark.
name: str | None
A name to display for the given benchmark. If not given, will be constructed from the
function name and given parameters.
params: dict[str, Any]
Fixed parameters to pass to the benchmark.
setUp: Callable[..., None]
A setup hook run before the benchmark. Must take all members of `params` as inputs.
tearDown: Callable[..., None]
A teardown hook run after the benchmark. Must take all members of `params` as inputs.
tags: tuple[str, ...]
Additional tags to attach for bookkeeping and selective filtering during runs.
"""

fn: Callable[..., Any]
name: str | None = field(default=None)
params: dict[str, Any] = field(repr=False, default_factory=dict)
setUp: Callable[..., None] = field(repr=False, default=NoOp)
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
tags: tuple[str, ...] = field(repr=False, default=())

def __post_init__(self):
if not self.name:
name = self.fn.__name__
if self.params:
name += "_" + "_".join(f"{k}={v}" for k, v in self.params.items())

super().__setattr__("name", name)
# TODO: Parse interface using `inspect`, attach to the class


def benchmark(
func: Callable[..., Any] | None = None,
params: dict[str, Any] | None = None,
10 changes: 2 additions & 8 deletions src/nnbench/runner.py
@@ -5,18 +5,12 @@
import os
import sys
from pathlib import Path
from typing import Any, Sequence, TypedDict
from typing import Any, Sequence

from nnbench.context import ContextProvider
from nnbench.core import Benchmark
from nnbench.types import Benchmark, BenchmarkResult
from nnbench.util import import_file_as_module, ismodule


class BenchmarkResult(TypedDict):
context: dict[str, Any]
benchmarks: list[dict[str, Any]]


logger = logging.getLogger(__name__)


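For orientation (not part of the diff): runner.py now pulls BenchmarkResult from nnbench.types, whose definition appears in the new file below. A minimal sketch of the record shape that TypedDict describes — the context values and the keys inside each per-benchmark dict are hypothetical examples, not fixed by this commit:

from nnbench.types import BenchmarkResult

result: BenchmarkResult = {
    # arbitrary context metadata attached to a run
    "context": {"python_version": "3.11.7"},
    # one dict per executed benchmark; the keys shown here are illustrative
    "benchmarks": [
        {"name": "accuracy_batch_size=16", "value": 0.912},
    ],
}
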
64 changes: 64 additions & 0 deletions src/nnbench/types.py
@@ -0,0 +1,64 @@
from dataclasses import dataclass, field
from typing import Any, Callable, TypedDict


class BenchmarkResult(TypedDict):
context: dict[str, Any]
benchmarks: list[dict[str, Any]]


def NoOp(**kwargs: Any) -> None:
pass


# TODO: Should this be frozen (since the setUp and tearDown hooks are empty returns)?
@dataclass(init=False)
class Params:
"""
A dataclass designed to hold benchmark parameters. This class is not functional
on its own, and needs to be subclassed according to your benchmarking workloads.
The main advantage over passing parameters as a dictionary is, of course,
static analysis and type safety for your benchmarking code.
"""

pass


@dataclass(frozen=True)
class Benchmark:
"""
Data model representing a benchmark. Subclass this to define your own custom benchmark.
Parameters
----------
fn: Callable[..., Any]
The function defining the benchmark.
name: str | None
A name to display for the given benchmark. If not given, will be constructed from the
function name and given parameters.
params: dict[str, Any]
Fixed parameters to pass to the benchmark.
setUp: Callable[..., None]
A setup hook run before the benchmark. Must take all members of `params` as inputs.
tearDown: Callable[..., None]
A teardown hook run after the benchmark. Must take all members of `params` as inputs.
tags: tuple[str, ...]
Additional tags to attach for bookkeeping and selective filtering during runs.
"""

fn: Callable[..., Any]
name: str | None = field(default=None)
params: dict[str, Any] = field(repr=False, default_factory=dict)
setUp: Callable[..., None] = field(repr=False, default=NoOp)
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
tags: tuple[str, ...] = field(repr=False, default=())

def __post_init__(self):
if not self.name:
name = self.fn.__name__
if self.params:
name += "_" + "_".join(f"{k}={v}" for k, v in self.params.items())

super().__setattr__("name", name)
# TODO: Parse interface using `inspect`, attach to the class
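
As a usage sketch of the new module (not part of this commit): per the docstrings above, Params is meant to be subclassed with workload-specific fields, and Benchmark derives its display name from the function and params in __post_init__. The MyParams fields, the accuracy function, and its return value below are made-up placeholders.

from dataclasses import dataclass

from nnbench.types import Benchmark, Params


@dataclass
class MyParams(Params):
    # hypothetical workload-specific parameters (see the Params docstring above)
    batch_size: int = 32
    learning_rate: float = 1e-3


def accuracy(batch_size: int, learning_rate: float) -> float:
    # placeholder for a real metric computation
    return 0.9


# setUp/tearDown default to NoOp, and name is derived in __post_init__,
# so only fn and the fixed params need to be supplied here.
bench = Benchmark(fn=accuracy, params={"batch_size": 16, "learning_rate": 1e-2})
print(bench.name)  # accuracy_batch_size=16_learning_rate=0.01

How a Params subclass is wired into a benchmark run is outside this diff; here it only illustrates the typed alternative to a plain parameter dict.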
