Skip to content

Commit

Permalink
Remove artifact facility, add Thunk generic
Browse files Browse the repository at this point in the history
Also stop the practice of binding partial parametrizations directly to
benchmarks. This has the effect that we can manipulate benchmark function
parameters if need be (for example by lazy-loading thunk parameters).

Changes the interface construction slightly to inject the partial
parametrization as defaults over the `inspect.Parameter` default values.
  • Loading branch information
nicholasjng committed Mar 19, 2024
1 parent 91976ee commit b88073d
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 161 deletions.
7 changes: 2 additions & 5 deletions src/nnbench/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import sys
import types
import warnings
from functools import partial, update_wrapper
from typing import Any, Callable, Iterable, Union, get_args, get_origin, overload

from nnbench.types import Benchmark
Expand Down Expand Up @@ -173,8 +172,7 @@ def decorator(fn: Callable) -> list[Benchmark]:
)
names.add(name)

wrapper = update_wrapper(partial(fn, **params), fn)
bm = Benchmark(wrapper, name=name, setUp=setUp, tearDown=tearDown, tags=tags)
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
benchmarks.append(bm)
return benchmarks

Expand Down Expand Up @@ -232,8 +230,7 @@ def decorator(fn: Callable) -> list[Benchmark]:
)
names.add(name)

wrapper = update_wrapper(partial(fn, **params), fn)
bm = Benchmark(wrapper, name=name, setUp=setUp, tearDown=tearDown, tags=tags)
bm = Benchmark(fn, name=name, params=params, setUp=setUp, tearDown=tearDown, tags=tags)
benchmarks.append(bm)
return benchmarks

Expand Down
171 changes: 15 additions & 156 deletions src/nnbench/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,7 @@

import copy
import inspect
import os
import shutil
import weakref
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass, field
from pathlib import Path
from tempfile import mkdtemp
from typing import (
Any,
Callable,
Expand All @@ -21,13 +15,6 @@

from nnbench.context import Context

try:
import fsspec

HAS_FSSPEC = True
except ImportError:
HAS_FSSPEC = False

T = TypeVar("T")
Variable = tuple[str, type, Any]

Expand Down Expand Up @@ -113,143 +100,9 @@ def expand(cls, bms: list[dict[str, Any]]) -> BenchmarkRecord:
# context data.


class ArtifactLoader:
@abstractmethod
def load(self) -> os.PathLike[str]:
"""Load the artifact"""


class LocalArtifactLoader(ArtifactLoader):
"""
ArtifactLoader for loading artifacts from a local file system.
Parameters
----------
path : str | os.PathLike[str]
The file system path to the artifact.
"""

def __init__(self, path: str | os.PathLike[str]) -> None:
self._path = path

def load(self) -> Path:
"""
Returns the path to the artifact on the local file system.
"""
return Path(self._path).resolve()


class FilePathArtifactLoader(ArtifactLoader):
"""
ArtifactLoader for loading artifacts using fsspec-supported file systems.
This allows for loading from various file systems like local, S3, GCS, etc.,
by using a unified API provided by fsspec.
Parameters
----------
path : str | os.PathLike[str]
The path to the artifact, which can include a protocol specifier (like 's3://') for remote access.
destination : str | os.PathLike[str] | None
The local directory to which remote artifacts will be downloaded. If provided, the model data will be persisted. Otherwise, local artifacts are cleaned.
storage_options : dict[str, Any] | None
Storage options for remote storage.
"""

def __init__(
self,
path: str | os.PathLike[str],
destination: str | os.PathLike[str] | None = None,
storage_options: dict[str, Any] | None = None,
) -> None:
self.source_path = str(path)
if destination:
target_path = str(Path(destination).resolve())
delete = False
else:
target_path = str(Path(mkdtemp()).resolve())
delete = True
self._finalizer = weakref.finalize(
self, lambda d, t: shutil.rmtree(t) if d else None, d=delete, t=target_path
)
self.target_path = target_path
self.storage_options = storage_options or {}

def load(self) -> Path:
"""
Loads the artifact and returns the local path.
Returns
-------
Path
The path to the artifact on the local filesystem.
Raises
------
ImportError
When fsspec is not installed.
"""
if not HAS_FSSPEC:
raise ImportError(
"class {self.__class__.__name__} requires `fsspec` to be installed. You can install it by running `python -m pip install --upgrade fsspec`"
)
fs = fsspec.filesystem(fsspec.utils.get_protocol(self.source_path))
fs.get(self.source_path, self.target_path, recursive=True)
return Path(self.target_path).resolve()


class Artifact(Generic[T], metaclass=ABCMeta):
"""
A base artifact class for loading (materializing) artifacts from disk or from remote storage.
This is a helper to convey which kind of type gets loaded for a benchmark in a type-safe way.
It is most useful when running models on already saved data or models, e.g. when
comparing a newly trained model against a baseline in storage.
You need to supply an ArtifactLoader with a load() method to load the Artifact into the
local system storage.
Subclasses need to implement the `Artifact.deserialize()` API, telling nnbench to
load the desired artifact from their path.
Parameters
----------
loader: ArtifactLoader
Loader to get the artifact.
"""

def __init__(self, loader: ArtifactLoader) -> None:
# Save the path for later just-in-time deserialization.
self.path = loader.load() # fetch the artifact from wherever it resides
self._value: T | None = None

@abstractmethod
def deserialize(self) -> None:
"""Deserialize the artifact."""

def is_deserialized(self) -> bool:
"""Checks if the artifact is already deserialized."""
return self._value is not None

def __str__(self) -> str:
return f"Artifact(path={self.path!r}, is_deserialized={self.is_deserialized()})"

def __repr__(self) -> str:
return f"{self.__class__.__name__}(path={self.path!r}, is_deserialized={self.is_deserialized()})"

@property
def value(self) -> T:
"""
Returns the deserialized artifact value.
Returns
-------
T
The deserialized value of the artifact.
"""
if self._value is None:
self.deserialize()
return self._value
class Thunk(Generic[T]):
def __call__(self) -> T:
raise NotImplementedError


@dataclass(init=False, frozen=True)
Expand Down Expand Up @@ -284,11 +137,15 @@ class Benchmark:
tags: tuple[str, ...]
Additional tags to attach for bookkeeping and selective filtering during runs.
interface: Interface
Interface of the benchmark function
Interface of the benchmark function.
params: dict[str, Any]
A partial parametrization to apply to the benchmark function. Internal only,
you should not need to set this yourself.
"""

fn: Callable[..., Any]
name: str | None = field(default=None)
params: dict[str, Any] = field(default_factory=dict)
setUp: Callable[..., None] = field(repr=False, default=NoOp)
tearDown: Callable[..., None] = field(repr=False, default=NoOp)
tags: tuple[str, ...] = field(repr=False, default=())
Expand All @@ -297,7 +154,7 @@ class Benchmark:
def __post_init__(self):
if not self.name:
super().__setattr__("name", self.fn.__name__)
super().__setattr__("interface", Interface.from_callable(self.fn))
super().__setattr__("interface", Interface.from_callable(self.fn, self.params))


@dataclass(frozen=True)
Expand Down Expand Up @@ -327,18 +184,20 @@ class Interface:
returntype: type

@classmethod
def from_callable(cls, fn: Callable) -> Interface:
def from_callable(cls, fn: Callable, defaults: dict[str, Any]) -> Interface:
"""
Creates an interface instance from the given callable.
"""
# Set follow_wrapped=False to get the partially filled interfaces.
# Set `follow_wrapped=False` to get the partially filled interfaces.
# Otherwise we get missing value errors for parameters supplied in benchmark decorators.
sig = inspect.signature(fn, follow_wrapped=False)
ret = sig.return_annotation
_defaults = {k: defaults.get(k, v.default) for k, v in sig.parameters.items()}
# defaults are the signature parameters, then the partial parametrization.
return cls(
tuple(sig.parameters.keys()),
tuple(p.annotation for p in sig.parameters.values()),
tuple(p.default for p in sig.parameters.values()),
tuple((k, v.annotation, v.default) for k, v in sig.parameters.items()),
tuple(_defaults.values()),
tuple((k, v.annotation, _defaults[k]) for k, v in sig.parameters.items()),
type(ret) if ret is None else ret,
)

0 comments on commit b88073d

Please sign in to comment.