diff --git a/src/nnbench/types.py b/src/nnbench/types.py index b930089..4a92656 100644 --- a/src/nnbench/types.py +++ b/src/nnbench/types.py @@ -1,8 +1,11 @@ """Useful type interfaces to override/subclass in benchmarking workflows.""" from __future__ import annotations +import os from dataclasses import dataclass, field -from typing import Any, Callable, TypedDict +from typing import Any, Callable, Generic, TypedDict, TypeVar + +T = TypeVar("T") class BenchmarkResult(TypedDict): @@ -14,6 +17,42 @@ def NoOp(**kwargs: Any) -> None: pass +class Artifact(Generic[T]): + """ + A base artifact class for loading (materializing) artifacts from disk or from remote storage. + + This is a helper to convey which kind of type gets loaded for a benchmark in a type-safe way. + It is most useful when running models on already saved data or models, e.g. when + comparing a newly trained model against a baseline in storage. + + Subclasses need to implement the `Artifact.materialize()` API, telling nnbench how to + load the desired artifact from a path. + + Parameters + ---------- + path: str | os.PathLike[str] + Path to the artifact files. + """ + + def __init__(self, path: str | os.PathLike[str]) -> None: + # Save the path for later just-in-time materialization. + self.path = path + self._value: T | None = None + + @classmethod + def materialize(cls) -> "Artifact": + """Load the artifact from storage.""" + raise NotImplementedError + + def value(self) -> T: + if self._value is None: + raise ValueError( + f"artifact has not been instantiated yet, " + f"perhaps you forgot to call {self.__class__.__name__}.materialize()?" + ) + return self._value + + # TODO: Should this be frozen (since the setUp and tearDown hooks are empty returns)? @dataclass(init=False) class Params: