diff --git a/docs/guides/artifacts.md b/docs/guides/artifacts.md
deleted file mode 100644
index 780e4d6..0000000
--- a/docs/guides/artifacts.md
+++ /dev/null
@@ -1,57 +0,0 @@
-# Using artifacts in nnbench
-With more complex benchmarking set-ups you will find yourself wanting to use static artifacts.
-These can be, for example, test and validation data, or serialized model files from a model registry.
-nnbench provides an artifact framework to handle these assets.
-This framework consists of the `ArtifactLoader`, and the `Artifact` base classes.
-
-Conceptually, they are intended to be used as follows:
-- `ArtifactLoader` to load the artifact onto the local filesystem,
-- `Artifact` to handle the artifact within nnbench and enable lazy loading into memory,
-
-You can implement your own derivative classes to handle custom logic for artifact deserialization and loading.
-Additionally, we provide some derived classes out of the box to handle local filepaths using filesystems, which are covered by the fsspec package.
-Let us now discuss each class of the framework in detail.
-
-## Using the `ArtifactLoader`
-The `ArtifactLoader` is an abstract base class for which you can implement your custom instance by overriding the `load()` method, which needs to return a file path either as string or a path-like object.
-You can see an example of it in the `LocalArtifactLoader` implementation that is also provided out of the box by nnbench.
-```python
-import os
-from pathlib import Path
-from nnbench.types import ArtifactLoader
-
-class LocalArtifactLoader(ArtifactLoader):
-    def __init__(self, path: str | os.PathLike[str]) -> None:
-        self._path = path
-
-    def load(self) -> Path:
-        return Path(self._path).resolve()
-```
-
-The use of these `ArtifactLoader` becomes apparent when you think about using it for remote artifact storage locations such as an S3 bucket.
-Besides the barebones `LocalArtifactLoader`, nnbench also provides the `nnbench.types.FilePathArtifactLoader`.
-To use it you have to `pip install fsspec` as an additional dependency.
-The class is then able to handle different filepaths, like S3 and GCS URIs.
-
-## Using the `Artifact` class
-The main purpose of the `Artifact` class is to load (deserialize) artifacts in a type-safe way, enabling autocompletion and type inference for your IDE to improve your developer experience.
-When implementing a custom `Artifact` subclass, you only have to override the `deserialize` method, which assigns the loaded object(s) to the `Artifact._value` member.
-You can use the `self.path` attribute to access the local filepath to the serialized artifact. This is provided by the `.load()` method of the appropriate `ArtifactLoader` that you have to pass upon instantiation.
-The artifact then exposes the wrapped object with the `.value` property, which returns the value of the internal `self._value` class member.
-Artifacts are loaded lazily.
-If you want to deserialize an artifact at a specific point, you can do so by calling the implemented `deserialize()` method.
-Otherwise, `deserialize()` is called internally when you first access the value.
-To provide a minimal example, here is how you could implement an `Artifact` for loading `numpy` arrays.
-
-```python
-import numpy as np
-from nnbench.types import Artifact
-from loaders import LocalArtifactLoader
-
-class NumpyArtifact(Artifact):
-    def deserialize(self) -> None:
-        self._value = np.load(self.path)
-
-array_artifact = NumpyArtifact(LocalArtifactLoader('path/to/array'))
-print(array_artifact.value)
-```
diff --git a/docs/guides/memoization.md b/docs/guides/memoization.md
new file mode 100644
index 0000000..1c04bbd
--- /dev/null
+++ b/docs/guides/memoization.md
@@ -0,0 +1,86 @@
+# Using memoization for memory-efficient benchmarks
+
+In machine learning workloads, models and datasets of greater-than-memory size are frequently encountered.
+Especially when loading and benchmarking several models in succession, for example in a parametrized benchmark, available memory can quickly become a bottleneck.
+
+To address this problem, this guide introduces **memos** as a way to reduce memory pressure when benchmarking multiple memory-intensive models and datasets sequentially.
+
+## Using the `nnbench.Memo` class
+// TODO: Move the class up into the main export.
+
+The key to efficient memory utilization in nnbench is _memoization_.
+nnbench itself provides the `nnbench.Memo` class, a generic base class that can be subclassed to yield a value and cache it for subsequent invocations.
+
+To subclass a memo, override the `Memo.__call__()` method like so:
+
+```python
+import numpy as np
+import nnbench
+
+class MyType:
+    """Contains a huge array, similar to a model."""
+    a: np.ndarray = np.zeros((10000, 10000))
+
+class MyMemo(nnbench.Memo[MyType]):
+
+    # TODO: Add a memo cache decorator.
+    def __call__(self) -> MyType:
+        return MyType()
+```
+
+`Memo.__call__()` takes no arguments, meaning that all external state necessary to compute the value needs to be passed to `Memo.__init__()`.
+In this way, nnbench's memos work similarly to, for example, [React's useMemo hook](https://react.dev/reference/react/useMemo).
+
+!!! Warning
+    You must explicitly hint the returned type in the `Memo.__call__()` annotation, and it needs to match the generic type specialization (the type in the square brackets in the class definition);
+    otherwise, nnbench will raise errors when validating benchmark parameters.
+
+## Supplying memoized values to benchmarks
+
+Memoization is especially useful when parametrizing benchmarks over models and datasets.
+
+Suppose we have a `Model` class wrapping a large (on the order of available memory) NumPy array.
+If we cannot load all models into memory at the same time in a benchmark run, we can load the (serialized) models one by one using nnbench memos.
+
+```python
+import nnbench
+
+import numpy as np
+
+
+class Model:
+    def __init__(self, arr: np.ndarray):
+        self.array = arr
+
+    def apply(self, arr: np.ndarray) -> np.ndarray:
+        return self.array @ arr
+
+class ModelMemo(nnbench.Memo[Model]):
+    def __init__(self, path):
+        self.path = path
+
+    # TODO: Add a memo cache decorator.
+    def __call__(self) -> Model:
+        arr = np.load(self.path)
+        return Model(arr)
+
+
+@nnbench.product(
+    model=[ModelMemo(p) for p in ("model1.npz", "model2.npz", "model3.npz")]
+)
+def accuracy(model: Model, data: np.ndarray) -> float:
+    return float(np.sum(model.apply(data)))
+```
+
+// TODO: Add tearDown task clearing the memo cache.
+
+After each benchmark, each model memo's corresponding value is evicted from nnbench's memoization cache.
+
+!!! Warning
+    If you evict a value before its last use in a benchmark, it will be recomputed, potentially slowing down benchmark execution considerably.
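+
+A teardown task addressing the TODO above could look roughly like the following sketch, which reuses the `Model` class and the parametrization from the previous example.
+The `(state, params)` callback signature and the `State` type are part of this changeset (the `product` decorator forwards `tearDown`, as the `core.py` changes below suggest), but `evict_memo()` is a hypothetical helper, since the memo cache API is still marked as a TODO here.
+
+```python
+from typing import Any, Mapping
+
+import numpy as np
+
+import nnbench
+from nnbench.types import State
+
+
+def evict_model(state: State, params: Mapping[str, Any]) -> None:
+    # Each parametrized benchmark gets its own model, so its memoized
+    # value can be evicted right after the benchmark has run.
+    evict_memo(params["model"])  # hypothetical memo cache helper
+
+
+@nnbench.product(
+    model=[ModelMemo(p) for p in ("model1.npz", "model2.npz", "model3.npz")],
+    tearDown=evict_model,
+)
+def precision(model: Model, data: np.ndarray) -> float:
+    return float(np.sum(model.apply(data)))
+```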
+
+## Summary
+
+- Use `nnbench.Memo`s to lazy-load and explicitly control the lifetime of objects with a large memory footprint.
+- Annotate memos with their specialized type to avoid problems with nnbench's type checking and parameter validation.
+- Use teardown tasks after benchmarks to evict memoized values from the memo cache.
diff --git a/mkdocs.yml b/mkdocs.yml
index eb7630e..0e29be4 100644
--- a/mkdocs.yml
+++ b/mkdocs.yml
@@ -22,7 +22,7 @@ nav:
     - guides/customization.md
     - guides/organization.md
    - guides/runners.md
-    - guides/artifacts.md
+    - guides/memoization.md
     - guides/transforms.md
   - Examples:
     - tutorials/index.md
diff --git a/src/nnbench/core.py b/src/nnbench/core.py
index 6c0cc28..31bb8ac 100644
--- a/src/nnbench/core.py
+++ b/src/nnbench/core.py
@@ -10,6 +10,7 @@
 from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload
 
 from nnbench.types import Benchmark
+from nnbench.types.types import NoOp
 from nnbench.types.util import is_memo, is_memo_type
 
 
@@ -52,10 +53,6 @@ def _default_namegen(fn: Callable, **kwargs: Any) -> str:
     return fn.__name__ + "_" + "_".join(f"{k}={v}" for k, v in kwargs.items())
 
 
-def NoOp(**kwargs: Any) -> None:
-    pass
-
-
 # Overloads for the ``benchmark`` decorator.
 # Case #1: Bare application without parentheses
 # @nnbench.benchmark
@@ -178,7 +175,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
             )
             names.add(name)
 
-            bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
+            bm = Benchmark(
+                fn,
+                name=name,
+                params=params,
+                setUp=setUp,
+                tearDown=tearDown,
+                tags=tags,
+            )
             benchmarks.append(bm)
 
         return benchmarks
@@ -236,7 +240,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
             )
             names.add(name)
 
-            bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
+            bm = Benchmark(
+                fn,
+                name=name,
+                params=params,
+                setUp=setUp,
+                tearDown=tearDown,
+                tags=tags,
+            )
             benchmarks.append(bm)
 
         return benchmarks
diff --git a/src/nnbench/runner.py b/src/nnbench/runner.py
index 089f38b..442b9c6 100644
--- a/src/nnbench/runner.py
+++ b/src/nnbench/runner.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import collections
 import contextlib
 import inspect
 import logging
@@ -15,7 +16,7 @@
 from typing import Any, Callable, Generator, Sequence, get_origin
 
 from nnbench.context import Context, ContextProvider
-from nnbench.types import Benchmark, BenchmarkRecord, Parameters
+from nnbench.types import Benchmark, BenchmarkRecord, Parameters, State
 from nnbench.types.util import is_memo, is_memo_type
 from nnbench.util import import_file_as_module, ismodule
 
@@ -247,6 +248,9 @@ def run(
         if not self.benchmarks:
             self.collect(path_or_module, tags)
 
+        family_sizes: dict[str, Any] = collections.defaultdict(int)
+        family_indices: dict[str, Any] = collections.defaultdict(int)
+
         if isinstance(context, Context):
             ctx = context
         else:
@@ -259,6 +263,9 @@
             warnings.warn(f"No benchmarks found in path/module {str(path_or_module)!r}.")
             return BenchmarkRecord(context=ctx, benchmarks=[])
 
+        for bm in self.benchmarks:
+            family_sizes[bm.fn.__name__] += 1
+
         if isinstance(params, Parameters):
             dparams = asdict(params)
         else:
@@ -274,6 +281,14 @@ def _maybe_dememo(v, expected_type):
             return v
 
         for benchmark in self.benchmarks:
+            bm_family = benchmark.fn.__name__
+            state = State(
+                name=benchmark.name,
+                family=bm_family,
+                family_size=family_sizes[bm_family],
+                family_index=family_indices[bm_family],
+            )
+            family_indices[bm_family] += 1
             bmtypes = dict(zip(benchmark.interface.names, benchmark.interface.types))
             bmparams = dict(zip(benchmark.interface.names, benchmark.interface.defaults))
             # TODO: Does this need a copy.deepcopy()?
@@ -291,14 +306,14 @@ def _maybe_dememo(v, expected_type):
                 "parameters": bmparams,
             }
             try:
-                benchmark.setUp(**bmparams)
+                benchmark.setUp(state, bmparams)
                 with timer(res):
                     res["value"] = benchmark.fn(**bmparams)
             except Exception as e:
                 res["error_occurred"] = True
                 res["error_message"] = str(e)
             finally:
-                benchmark.tearDown(**bmparams)
+                benchmark.tearDown(state, bmparams)
             results.append(res)
 
         return BenchmarkRecord(
diff --git a/src/nnbench/types/__init__.py b/src/nnbench/types/__init__.py
index c7be11a..90c64a3 100644
--- a/src/nnbench/types/__init__.py
+++ b/src/nnbench/types/__init__.py
@@ -1 +1 @@
-from .types import Benchmark, BenchmarkRecord, Memo, Parameters
+from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State
diff --git a/src/nnbench/types/types.py b/src/nnbench/types/types.py
index 0d659f4..8b74eaf 100644
--- a/src/nnbench/types/types.py
+++ b/src/nnbench/types/types.py
@@ -8,7 +8,8 @@
 import logging
 import threading
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, Literal, TypeVar
+from types import MappingProxyType
+from typing import Any, Callable, Generic, Literal, Mapping, Protocol, TypeVar
 
 from nnbench.context import Context
 
@@ -52,10 +53,14 @@ def wrapper(self, *args, **kwargs):
     return wrapper
 
 
-def NoOp(**kwargs: Any) -> None:
+def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
     pass
 
 
+class CallbackProtocol(Protocol):
+    def __call__(self, state: State, params: Mapping[str, Any]) -> None: ...
+
+
 @dataclass(frozen=True)
 class BenchmarkRecord:
     context: Context
@@ -133,6 +138,14 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
     # context data.
 
 
+@dataclass(frozen=True)
+class State:
+    name: str
+    family: str
+    family_size: int
+    family_index: int
+
+
 class Memo(Generic[T]):
     """Abstract base class for memoized values in benchmark runs."""
 
@@ -190,10 +203,12 @@ class Benchmark:
     """
 
     fn: Callable[..., Any]
-    name: str | None = field(default=None)
+    name: str = field(default="")
     params: dict[str, Any] = field(default_factory=dict)
-    setUp: Callable[..., None] = field(repr=False, default=NoOp)
-    tearDown: Callable[..., None] = field(repr=False, default=NoOp)
+    setUp: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
+    tearDown: Callable[[State, Mapping[str, Any]], None] = field(
+        repr=False, default=NoOp
+    )
     tags: tuple[str, ...] = field(repr=False, default=())
     interface: Interface = field(init=False, repr=False)
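
Taken together, these changes alter the user-facing callback contract: `setUp` and `tearDown` no longer receive the benchmark parameters as keyword arguments, but a `State` object plus a read-only parameter mapping, matching the new `CallbackProtocol`. The sketch below shows a callback written against the new interface; the callback and benchmark names are illustrative only, not part of this diff.

```python
from typing import Any, Mapping

import nnbench
from nnbench.types import State


def announce(state: State, params: Mapping[str, Any]) -> None:
    # State carries the benchmark name and its family bookkeeping;
    # family_index is zero-based and incremented by the runner per run.
    print(f"running {state.name} ({state.family_index + 1}/{state.family_size})")


@nnbench.benchmark(setUp=announce)
def add_one(x: int = 1) -> int:
    return x + 1
```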