Merge branch 'main' into memo
maxmynter committed Mar 26, 2024
2 parents bde6455 + 20e9fe8 commit 9449cb0
Showing 7 changed files with 143 additions and 73 deletions.
57 changes: 0 additions & 57 deletions docs/guides/artifacts.md

This file was deleted.

86 changes: 86 additions & 0 deletions docs/guides/memoization.md
@@ -0,0 +1,86 @@
# Using memoization for memory-efficient benchmarks

In machine learning workloads, models and datasets that exceed available memory are common.
Especially when loading and benchmarking several models in succession, for example in a parametrized benchmark, available memory can quickly become a bottleneck.

To address this problem, this guide introduces **memos** as a way to reduce memory pressure when benchmarking multiple memory-intensive models and datasets sequentially.

## Using the `nnbench.Memo` class
// TODO: Move the class up into the main export.

The key to efficient memory utilization in nnbench is _memoization_.
nnbench provides the `nnbench.Memo` class, a generic base class that can be subclassed to compute a value and cache it for subsequent invocations.

To create your own memo, subclass `nnbench.Memo` and override the `Memo.__call__()` method like so:

```python
import numpy as np
import nnbench

class MyType:
    """Contains a huge array, similar to a model."""
    a: np.ndarray = np.zeros((10000, 10000))


class MyMemo(nnbench.Memo[MyType]):
    # TODO: Add a memo cache decorator.
    def __call__(self) -> MyType:
        return MyType()
```
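
Until that decorator lands, a memo can also cache its value by hand. Below is a minimal sketch of such manual caching, reusing the `MyType` class from above; the `CachingMemo` name and its instance-level cache are illustrative, not part of nnbench's API:

```python
import threading

import nnbench


class CachingMemo(nnbench.Memo[MyType]):
    def __init__(self) -> None:
        self._value: MyType | None = None
        self._lock = threading.Lock()

    def __call__(self) -> MyType:
        # Compute the value on the first invocation, reuse it afterwards.
        with self._lock:
            if self._value is None:
                self._value = MyType()
        return self._value
```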

`nnbench.Memo` objects are called without any arguments, so all external state necessary to compute the value must be passed to `Memo.__init__()`.
In this way, nnbench's memos work similarly to e.g. [React's useMemo hook](https://react.dev/reference/react/useMemo).

!!! Warning
    You must explicitly annotate the return type of `Memo.__call__()`, and the annotation must match the generic type specialization (the type in the square brackets in the class definition).
    Otherwise, nnbench will raise errors when validating benchmark parameters.

## Supplying memoized values to benchmarks

Memoization is especially useful when parametrizing benchmarks over models and datasets.

Suppose we have a `Model` class wrapping a large (on the order of available memory) NumPy array.
If we cannot load all models into memory at the same time during a benchmark run, we can load the (serialized) models one by one using nnbench memos.

```python
import nnbench

import numpy as np


class Model:
    def __init__(self, arr: np.ndarray):
        self.array = arr

    def apply(self, arr: np.ndarray) -> np.ndarray:
        return self.array @ arr


class ModelMemo(nnbench.Memo[Model]):
    def __init__(self, path: str):
        self.path = path

    # TODO: Add a memo cache decorator.
    def __call__(self) -> Model:
        # Load the serialized model weights only when the memo is invoked.
        arr = np.load(self.path)
        return Model(arr)


@nnbench.product(
    model=[ModelMemo(p) for p in ("model1.npy", "model2.npy", "model3.npy")]
)
def accuracy(model: Model, data: np.ndarray) -> float:
    return float(np.sum(model.apply(data)))
```
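
For context, a run supplying the remaining `data` parameter could look roughly like the sketch below; the random array is a stand-in for a real evaluation set, and the exact runner invocation may differ:

```python
import numpy as np

import nnbench

runner = nnbench.BenchmarkRunner()
# Each model memo is resolved to its value only when its benchmark runs,
# so only one model needs to be held in memory at a time.
record = runner.run(__name__, params={"data": np.random.rand(10000, 100)})
```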

// TODO: Add a tearDown task clearing the memo cache.

After each benchmark, the corresponding model memo's value is evicted from nnbench's memoization cache.

!!! Warning
    If you evict a value before its last use in a benchmark run, it will be recomputed, potentially slowing down benchmark execution significantly.
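
With the `(state, params)` callback signature introduced in this commit, such eviction can be driven from a `tearDown` task. The sketch below is hypothetical: the eviction call is left as a placeholder for whatever cache API the TODO above settles on, and only the `State` bookkeeping is taken from this commit:

```python
from typing import Any, Mapping

from nnbench.types import State


def evict_model(state: State, params: Mapping[str, Any]) -> None:
    # Evict only after the last benchmark in the family has used the model,
    # otherwise the memo would be recomputed (see the warning above).
    if state.family_index == state.family_size - 1:
        ...  # hypothetical: evict the memoized model value from the cache
```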

## Summary

- Use `nnbench.Memo`s to lazy-load and explicitly control the lifetime of objects with a large memory footprint.
- Annotate memos with their specialized type to avoid problems with nnbench's type checking and parameter validation.
- Use teardown tasks after benchmarks to evict memoized values from the memo cache.
2 changes: 1 addition & 1 deletion mkdocs.yml
@@ -22,7 +22,7 @@ nav:
       - guides/customization.md
       - guides/organization.md
       - guides/runners.md
-      - guides/artifacts.md
+      - guides/memoization.md
       - guides/transforms.md
   - Examples:
       - tutorials/index.md
23 changes: 17 additions & 6 deletions src/nnbench/core.py
@@ -10,6 +10,7 @@
 from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload
 
 from nnbench.types import Benchmark
+from nnbench.types.types import NoOp
 from nnbench.types.util import is_memo, is_memo_type
 
 
@@ -52,10 +53,6 @@ def _default_namegen(fn: Callable, **kwargs: Any) -> str:
     return fn.__name__ + "_" + "_".join(f"{k}={v}" for k, v in kwargs.items())
 
 
-def NoOp(**kwargs: Any) -> None:
-    pass
-
-
 # Overloads for the ``benchmark`` decorator.
 # Case #1: Bare application without parentheses
 # @nnbench.benchmark
@@ -178,7 +175,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
                 )
             names.add(name)
 
-            bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
+            bm = Benchmark(
+                fn,
+                name=name,
+                params=params,
+                setUp=setUp,
+                tearDown=tearDown,
+                tags=tags,
+            )
             benchmarks.append(bm)
         return benchmarks
 
@@ -236,7 +240,14 @@ def decorator(fn: Callable) -> list[Benchmark]:
                 )
             names.add(name)
 
-            bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
+            bm = Benchmark(
+                fn,
+                name=name,
+                params=params,
+                setUp=setUp,
+                tearDown=tearDown,
+                tags=tags,
+            )
             benchmarks.append(bm)
         return benchmarks
 
21 changes: 18 additions & 3 deletions src/nnbench/runner.py
@@ -2,6 +2,7 @@
 
 from __future__ import annotations
 
+import collections
 import contextlib
 import inspect
 import logging
@@ -15,7 +16,7 @@
 from typing import Any, Callable, Generator, Sequence, get_origin
 
 from nnbench.context import Context, ContextProvider
-from nnbench.types import Benchmark, BenchmarkRecord, Parameters
+from nnbench.types import Benchmark, BenchmarkRecord, Parameters, State
 from nnbench.types.util import is_memo, is_memo_type
 from nnbench.util import import_file_as_module, ismodule
 
@@ -247,6 +248,9 @@ def run(
         if not self.benchmarks:
             self.collect(path_or_module, tags)
 
+        family_sizes: dict[str, Any] = collections.defaultdict(int)
+        family_indices: dict[str, Any] = collections.defaultdict(int)
+
         if isinstance(context, Context):
             ctx = context
         else:
@@ -259,6 +263,9 @@
             warnings.warn(f"No benchmarks found in path/module {str(path_or_module)!r}.")
             return BenchmarkRecord(context=ctx, benchmarks=[])
 
+        for bm in self.benchmarks:
+            family_sizes[bm.fn.__name__] += 1
+
         if isinstance(params, Parameters):
             dparams = asdict(params)
         else:
@@ -274,6 +281,14 @@ def _maybe_dememo(v, expected_type):
             return v
 
         for benchmark in self.benchmarks:
+            bm_family = benchmark.fn.__name__
+            state = State(
+                name=benchmark.name,
+                family=bm_family,
+                family_size=family_sizes[bm_family],
+                family_index=family_indices[bm_family],
+            )
+            family_indices[bm_family] += 1
             bmtypes = dict(zip(benchmark.interface.names, benchmark.interface.types))
             bmparams = dict(zip(benchmark.interface.names, benchmark.interface.defaults))
             # TODO: Does this need a copy.deepcopy()?
@@ -291,14 +306,14 @@ def _maybe_dememo(v, expected_type):
                 "parameters": bmparams,
             }
             try:
-                benchmark.setUp(**bmparams)
+                benchmark.setUp(state, bmparams)
                 with timer(res):
                     res["value"] = benchmark.fn(**bmparams)
             except Exception as e:
                 res["error_occurred"] = True
                 res["error_message"] = str(e)
             finally:
-                benchmark.tearDown(**bmparams)
+                benchmark.tearDown(state, bmparams)
             results.append(res)
 
         return BenchmarkRecord(
2 changes: 1 addition & 1 deletion src/nnbench/types/__init__.py
@@ -1 +1 @@
-from .types import Benchmark, BenchmarkRecord, Memo, Parameters
+from .types import Benchmark, BenchmarkRecord, Memo, Parameters, State
25 changes: 20 additions & 5 deletions src/nnbench/types/types.py
@@ -8,7 +8,8 @@
 import logging
 import threading
 from dataclasses import dataclass, field
-from typing import Any, Callable, Generic, Literal, TypeVar
+from types import MappingProxyType
+from typing import Any, Callable, Generic, Literal, Mapping, Protocol, TypeVar
 
 from nnbench.context import Context
 
@@ -52,10 +53,14 @@ def wrapper(self, *args, **kwargs):
     return wrapper
 
 
-def NoOp(**kwargs: Any) -> None:
+def NoOp(state: State, params: Mapping[str, Any] = MappingProxyType({})) -> None:
     pass
 
 
+class CallbackProtocol(Protocol):
+    def __call__(self, state: State, params: Mapping[str, Any]) -> None: ...
+
+
 @dataclass(frozen=True)
 class BenchmarkRecord:
     context: Context
@@ -133,6 +138,14 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
         # context data.
 
 
+@dataclass(frozen=True)
+class State:
+    name: str
+    family: str
+    family_size: int
+    family_index: int
+
+
 class Memo(Generic[T]):
     """Abstract base class for memoized values in benchmark runs."""
 
@@ -190,10 +203,12 @@ class Benchmark:
     """
 
     fn: Callable[..., Any]
-    name: str | None = field(default=None)
+    name: str = field(default="")
     params: dict[str, Any] = field(default_factory=dict)
-    setUp: Callable[..., None] = field(repr=False, default=NoOp)
-    tearDown: Callable[..., None] = field(repr=False, default=NoOp)
+    setUp: Callable[[State, Mapping[str, Any]], None] = field(repr=False, default=NoOp)
+    tearDown: Callable[[State, Mapping[str, Any]], None] = field(
+        repr=False, default=NoOp
+    )
     tags: tuple[str, ...] = field(repr=False, default=())
     interface: Interface = field(init=False, repr=False)
