Feature: Thunks as an alternative to artifacts (#120)
* Add transform submodule, parameter compression transform

Compressing parameters in a transform is preferable to doing it directly
in the benchmark runner, which would take over a responsibility that
belongs to the newly introduced transform.

Refactors `nnbench.io.transform` -> `nnbench.transforms`, the latter being
its own submodule. This is useful when adding new built-in transforms,
since they no longer have to go into a single file. The import path change
is sketched below.
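
For reference, the import path moves roughly as follows (the `Transform`
class name is assumed for illustration; only the module move is stated
in this commit):

    # before: transforms lived inside the io package
    from nnbench.io.transform import Transform  # class name assumed for illustration

    # after: transforms are a submodule of their own, with room for more built-ins
    from nnbench.transforms import Transform  # class name assumed for illustration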

* Add `typecheck` flag to benchmark runner to disable typechecks

Adds two conditional branches that skip type checks in the `_check()`
method.

This is useful when prototyping new features, where benchmark inputs
might not exactly match the requested types. A usage sketch follows below.
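
A usage sketch (call shape illustrative), assuming the flag is exposed as
a keyword argument on the runner; whether it lives on the constructor or
on `run()` is not shown in this excerpt:

    import nnbench

    runner = nnbench.BenchmarkRunner()
    # Prototyping mode: inputs that do not exactly match the annotated types
    # no longer fail the _check() step. The exact flag placement is an assumption.
    record = runner.run("__main__", params={"model": "stub"}, typecheck=False)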

* Remove artifact facility, add Thunk generic

Also stops binding partial parametrizations directly to benchmarks. This
means benchmark function parameters can be manipulated if need be
(for example, by lazily loading thunk parameters).

Changes the interface construction slightly: the partial parametrization
is injected as defaults on top of the `inspect.Parameter` default values,
as sketched below.
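
An illustrative sketch (not the actual runner code) of the described
default injection, laying the partial parametrization over the function's
own `inspect.Parameter` defaults when the interface is built:

    import inspect
    from typing import Any

    def overlay_defaults(fn, partial_params: dict[str, Any]) -> list[inspect.Parameter]:
        # For each function parameter, prefer the partially supplied value as
        # the new default; otherwise keep the declared default untouched.
        new_params = []
        for name, param in inspect.signature(fn).parameters.items():
            if name in partial_params:
                param = param.replace(default=partial_params[name])
            new_params.append(param)
        return new_params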

* Skip typecheck immediately, add thunk support

Slightly changes parameter construction and adds a dethunking step right
before the benchmark loop.

This means that the thunk values are accessed at the latest possible
time, which is just before benchmark execution.

Moves the context construction ahead of the empty-collection check, so
that a constructed context is returned even when no benchmarks are found.

Adds two C++-style thunk helpers: `is_thunk`, which decides whether a
value is a thunk, and `is_thunk_type`, which decides whether a type
annotation is a thunk type.

The whole thunk facility is designed to work with both the
`nnbench.types.Thunk` type and general anonymous functions; a sketch of
the dethunking step follows below.
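
Roughly, the dethunking step amounts to the following sketch (illustrative
only; the real `is_thunk`/`is_thunk_type` helpers and runner integration
may differ):

    import inspect
    from typing import Any

    def dethunk(params: dict[str, Any]) -> dict[str, Any]:
        # Treat zero-argument callables (nnbench Thunks or plain lambdas) as
        # thunks and evaluate them here, just before the benchmark loop, so
        # their values are materialized as late as possible.
        def looks_like_thunk(value: Any) -> bool:
            if not callable(value):
                return False
            try:
                return len(inspect.signature(value).parameters) == 0
            except (TypeError, ValueError):
                return False

        return {k: v() if looks_like_thunk(v) else v for k, v in params.items()}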

* Change thunk -> memo, add type check bypass for memos

In the current setup, properly typed memos and callables pass the type
checker.

Factors the types out into their own submodule, to be split later into
its major constituents. A minimal memo is sketched below.
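
A minimal memo, mirroring the pattern of the migrated example further down
(the `Memo` generic and its import are taken from the diff; the concrete
class here is made up for illustration):

    from functools import cache

    from nnbench.types import Memo

    class SquareMemo(Memo[int]):
        # Computes (and caches) its value only when called, i.e. just in time.
        def __init__(self, n: int):
            self.n = n

        @cache
        def __call__(self) -> int:
            return self.n * self.n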

* Migrate artifact benchmarking code to memo syntax

Showcases memo subclassing and parametrization, and trivializes the
`run()` command again.

As a downside, a single run can execute either the per-class benchmarks or
the aggregates, but not both side by side (that would require `params`
injection); see the run sketch below.
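
With parameters supplied through memos, running the example reduces to
something like the following (the tag names come from the diff below; the
exact runner call shape is an assumption, not part of this commit):

    import nnbench

    runner = nnbench.BenchmarkRunner()
    # Collect and run only the per-class metric benchmarks; the aggregates
    # would need a separate run, per the caveat above.
    record = runner.run(
        "examples/artifact_benchmarking/src/benchmark.py",
        tags=("metric", "per-class"),
    )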

* Migrate tests to new params assumptions

Partial parametrizations are not bound eagerly to the benchmark functions
anymore, which makes it simpler to inject memos and de-memoize variables
just in time for execution.

What remains is validating that benchmarking a sequence of models with
intermittent garbage collection actually reaps each model once its
benchmark is done.
nicholasjng authored Mar 21, 2024
1 parent d09cd3c commit 65fc45b
Showing 11 changed files with 303 additions and 346 deletions.
1 change: 0 additions & 1 deletion examples/artifact_benchmarking/requirements.txt
@@ -22,7 +22,6 @@ mpmath==1.3.0
multidict==6.0.5
multiprocess==0.70.16
networkx==3.2.1
nnbench @ git+https://github.com/aai-institute/nnbench.git
numpy==1.26.4
packaging==23.2
pandas==2.2.1
199 changes: 163 additions & 36 deletions examples/artifact_benchmarking/src/benchmark.py
@@ -1,13 +1,100 @@
import os
import tempfile
import time
from functools import lru_cache, partial
from functools import cache, lru_cache, partial
from typing import Sequence

import torch
from datasets import Dataset, load_dataset
from torch.nn import Module
from torch.utils.data import DataLoader
from training.training import tokenize_and_align_labels
from transformers import (
AutoModelForTokenClassification,
AutoTokenizer,
DataCollatorForTokenClassification,
)
from transformers.tokenization_utils_base import PreTrainedTokenizerBase

import nnbench
from nnbench.types import Memo


class TokenClassificationModelMemo(Memo[Module]):
def __init__(self, path: str):
self.path = path

@cache
def __call__(self) -> Module:
model: Module = AutoModelForTokenClassification.from_pretrained(
self.path, use_safetensors=True
)
return model

def __str__(self):
return self.path


class TokenizerMemo(Memo[PreTrainedTokenizerBase]):
def __init__(self, path: str):
self.path = path

@cache
def __call__(self) -> PreTrainedTokenizerBase:
tokenizer: PreTrainedTokenizerBase = AutoTokenizer.from_pretrained(self.path)
return tokenizer

def __str__(self):
return self.path


class ConllValidationMemo(Memo[Dataset]):
def __init__(self, path: str, split: str):
self.path = path
self.split = split

@cache
def __call__(self) -> Dataset:
dataset = load_dataset(self.path)
path = dataset.cache_files[self.split][0]["filename"]
dataset = Dataset.from_file(path)
return dataset

def __str__(self):
return self.path + "/" + self.split


class IndexLabelMapMemo(Memo[dict[int, str]]):
def __init__(self, path: str, split: str):
self.path = path
self.split = split

@cache
def __call__(self) -> dict[int, str]:
dataset = load_dataset(self.path)
path = dataset.cache_files[self.split][0]["filename"]
dataset = Dataset.from_file(path)
label_names: Sequence[str] = dataset.features["ner_tags"].feature.names
id2label = {i: label for i, label in enumerate(label_names)}
return id2label

def __str__(self):
return self.path + "/" + self.split


@cache
def make_dataloader(tokenizer, dataset):
tokenized_dataset = dataset.map(
lambda examples: tokenize_and_align_labels(tokenizer, examples),
batched=True,
remove_columns=dataset.column_names,
)
return DataLoader(
tokenized_dataset,
shuffle=False,
collate_fn=DataCollatorForTokenClassification(tokenizer, padding=True),
batch_size=8,
)


@lru_cache
@@ -41,39 +128,42 @@ def run_eval_loop(model, dataloader, padding_id=-100):
return true_positives, false_positives, true_negatives, false_negatives


@nnbench.benchmark(tags=("metric", "aggregate"))
def precision(model: Module, test_dataloader: DataLoader, padding_id: int = -100) -> float:
tp, fp, tn, fn = run_eval_loop(model, test_dataloader, padding_id)
precision = tp / (tp + fp + 1e-6)
return torch.mean(precision).item()


parametrize_label = partial(
nnbench.parametrize,
(
{"class_label": "O"},
{"class_label": "B-PER"},
{"class_label": "I-PER"},
{"class_label": "B-ORG"},
{"class_label": "I-ORG"},
{"class_label": "B-LOC"},
{"class_label": "I-LOC"},
{"class_label": "B-MISC"},
{"class_label": "I-MISC"},
),
nnbench.product,
model=[TokenClassificationModelMemo("dslim/distilbert-NER")],
tokenizer=[TokenizerMemo("dslim/distilbert-NER")],
valdata=[ConllValidationMemo(path="conllpp", split="validation")],
index_label_mapping=[IndexLabelMapMemo(path="conllpp", split="validation")],
class_label=["O", "B-PER", "I-PER", "B-ORG", "I-ORG", "B-LOC", "I-LOC", "B-MISC", "I-MISC"],
tags=("metric", "per-class"),
)()


@nnbench.benchmark(tags=("metric", "aggregate"))
def precision(
model: Module,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
padding_id: int = -100,
) -> float:
dataloader = make_dataloader(tokenizer, valdata)
tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id)
precision = tp / (tp + fp + 1e-6)
return torch.mean(precision).item()


@parametrize_label
def precision_per_class(
class_label: str,
model: Module,
test_dataloader: DataLoader,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
index_label_mapping: dict[int, str],
padding_id: int = -100,
) -> float:
tp, fp, tn, fn = run_eval_loop(model, test_dataloader, padding_id)
dataloader = make_dataloader(tokenizer, valdata)

tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id)
precision_values = tp / (tp + fp + 1e-6)
for idx, lbl in index_label_mapping.items():
if lbl == class_label:
@@ -82,8 +172,15 @@ def precision_per_class(


@nnbench.benchmark(tags=("metric", "aggregate"))
def recall(model: Module, test_dataloader: DataLoader, padding_id: int = -100) -> float:
tp, fp, tn, fn = run_eval_loop(model, test_dataloader, padding_id)
def recall(
model: Module,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
padding_id: int = -100,
) -> float:
dataloader = make_dataloader(tokenizer, valdata)

tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id)
recall = tp / (tp + fn + 1e-6)
return torch.mean(recall).item()

@@ -92,11 +189,14 @@ def recall(model: Module, test_dataloader: DataLoader, padding_id: int = -100) -
def recall_per_class(
class_label: str,
model: Module,
test_dataloader: DataLoader,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
index_label_mapping: dict[int, str],
padding_id: int = -100,
) -> float:
tp, fp, tn, fn = run_eval_loop(model, test_dataloader, padding_id)
dataloader = make_dataloader(tokenizer, valdata)

tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id)
recall_values = tp / (tp + fn + 1e-6)
for idx, lbl in index_label_mapping.items():
if lbl == class_label:
@@ -105,8 +205,15 @@ def recall_per_class(


@nnbench.benchmark(tags=("metric", "aggregate"))
def f1(model: Module, test_dataloader: DataLoader, padding_id: int = -100) -> float:
tp, fp, tn, fn = run_eval_loop(model, test_dataloader, padding_id)
def f1(
model: Module,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
padding_id: int = -100,
) -> float:
dataloader = make_dataloader(tokenizer, valdata)

tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id)
precision = tp / (tp + fp + 1e-6)
recall = tp / (tp + fn + 1e-6)
f1 = 2 * (precision * recall) / (precision + recall + 1e-6)
@@ -117,11 +224,14 @@ def f1(model: Module, test_dataloader: DataLoader, padding_id: int = -100) -> fl
def f1_per_class(
class_label: str,
model: Module,
test_dataloader: DataLoader,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
index_label_mapping: dict[int, str],
padding_id: int = -100,
) -> float:
tp, fp, tn, fn = run_eval_loop(model, test_dataloader, padding_id)
dataloader = make_dataloader(tokenizer, valdata)

tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id)
precision = tp / (tp + fp + 1e-6)
recall = tp / (tp + fn + 1e-6)
f1_values = 2 * (precision * recall) / (precision + recall + 1e-6)
@@ -132,8 +242,15 @@ def f1_per_class(


@nnbench.benchmark(tags=("metric", "aggregate"))
def accuracy(model: Module, test_dataloader: DataLoader, padding_id: int = -100) -> float:
tp, fp, tn, fn = run_eval_loop(model, test_dataloader, padding_id)
def accuracy(
model: Module,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
padding_id: int = -100,
) -> float:
dataloader = make_dataloader(tokenizer, valdata)

tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id)
accuracy = (tp + tn) / (tp + tn + fp + fn + 1e-6)
return torch.mean(accuracy).item()

@@ -142,11 +259,14 @@ def accuracy(model: Module, test_dataloader: DataLoader, padding_id: int = -100)
def accuracy_per_class(
class_label: str,
model: Module,
test_dataloader: DataLoader,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
index_label_mapping: dict[int, str],
padding_id: int = -100,
) -> dict[str, float]:
tp, fp, tn, fn = run_eval_loop(model, test_dataloader, padding_id)
dataloader = make_dataloader(tokenizer, valdata)

tp, fp, tn, fn = run_eval_loop(model, dataloader, padding_id)
accuracy_values = (tp + tn) / (tp + tn + fp + fn + 1e-6)
for idx, lbl in index_label_mapping.items():
if lbl == class_label:
@@ -162,12 +282,19 @@ def model_configuration(model: Module) -> dict:


@nnbench.benchmark(tags=("model-meta", "inference-time"))
def avg_inference_time_ns(model: Module, test_dataloader: DataLoader, avg_n: int = 100) -> float:
def avg_inference_time_ns(
model: Module,
tokenizer: PreTrainedTokenizerBase,
valdata: Dataset,
avg_n: int = 100,
) -> float:
dataloader = make_dataloader(tokenizer, valdata)

start_time = time.perf_counter()
model.eval()
num_datapoints = 0
with torch.no_grad():
for batch in test_dataloader:
for batch in dataloader:
if num_datapoints >= avg_n:
break
num_datapoints += len(batch)