Skip to content

Commit

Permalink
Merge pull request #10 from aai-institute/add-reporter-entrypoint
Browse files Browse the repository at this point in the history
Add reporter entrypoint
  • Loading branch information
nicholasjng authored Jan 22, 2024
2 parents bf1d968 + a2c1804 commit e2d4e6b
Show file tree
Hide file tree
Showing 3 changed files with 47 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
types_or: [ python, pyi ]
args: [--ignore-missing-imports, --scripts-are-modules]
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.1.13
rev: v0.1.14
hooks:
- id: ruff
args: [ --fix, --exit-non-zero-on-fix ]
Expand All @@ -29,6 +29,6 @@ repos:
args: [-c, pyproject.toml]
additional_dependencies: ["bandit[toml]"]
- repo: https://github.com/jsh9/pydoclint
rev: 0.3.8
rev: 0.3.9
hooks:
- id: pydoclint
20 changes: 18 additions & 2 deletions src/nnbench/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
"""A framework for organizing and running benchmark workloads on machine learning models."""

from importlib.metadata import PackageNotFoundError, version
from importlib.metadata import PackageNotFoundError, entry_points, version

try:
__version__ = version("nnbench")
Expand All @@ -10,5 +10,21 @@

# TODO: This naming is unfortunate
from .core import benchmark, parametrize
from .reporter import BaseReporter
from .reporter import BaseReporter, register_reporter
from .types import Benchmark, Params


def add_reporters():
eps = entry_points()

if hasattr(eps, "select"): # Python 3.10+ / importlib.metadata >= 3.9.0
reporters = eps.select(group="nnbench.reporters")
else:
reporters = eps.get("nnbench.reporters", []) # type: ignore

for rep in reporters:
key, clsname = rep.name.split("=", 1)
register_reporter(key, clsname)


add_reporters()
27 changes: 27 additions & 0 deletions src/nnbench/reporter.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
"""
A lightweight interface for refining, displaying, and streaming benchmark results to various sinks.
"""
from __future__ import annotations

import importlib
import sys
import types
from typing import Any
Expand Down Expand Up @@ -58,3 +60,28 @@ def report(self, result: BenchmarkResult) -> None:
reporter_registry: types.MappingProxyType[str, type[BaseReporter]] = types.MappingProxyType(
_reporter_registry
)


def register_reporter(key: str, cls_or_name: str | type[BaseReporter]) -> None:
"""
Register a reporter class by its fully qualified module path.
Parameters
----------
key: str
The key to register the reporter under. Subsequently, this key can be used in place
of reporter classes in code.
cls_or_name: str | type[BaseReporter]
Name of or full module path to the reporter class. For example, when registering a class
``MyReporter`` located in ``my_module``, ``name`` should be ``my_module.MyReporter``.
"""

if isinstance(cls_or_name, str):
name = cls_or_name
modname, clsname = name.rsplit(".", 1)
mod = importlib.import_module(modname)
cls = getattr(mod, clsname)
_reporter_registry[key] = cls
else:
# name = cls_or_name.__module__ + "." + cls_or_name.__qualname__
_reporter_registry[key] = cls_or_name

0 comments on commit e2d4e6b

Please sign in to comment.