print results table at the end of an evaluation session
nkaenzig committed Mar 21, 2024
1 parent 7375768 commit 363c258
Showing 2 changed files with 61 additions and 6 deletions.
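The new `_print_table` helper below renders one `rich` table per stage/dataset pair, with a Mean, Stdev and per-run column for each metric. As a standalone sketch of that pattern (the metric name and scores here are invented for illustration):

import statistics

from rich import console as rich_console
from rich import table as rich_table

# Hypothetical per-run scores for a single metric (illustrative values only).
values = [0.89, 0.91, 0.93]

metrics_table = rich_table.Table(title="\nVal Dataset 0", title_style="bold")
metrics_table.add_column("Metric", style="cyan")
metrics_table.add_column("Mean", justify="right", style="magenta")
metrics_table.add_column("Stdev", justify="right", style="magenta")
for i in range(len(values)):
    metrics_table.add_column(f"Run {i}", justify="right", style="magenta")

# One row per metric: name, aggregates, then the raw per-run values.
row = ["MulticlassAccuracy", statistics.mean(values), statistics.stdev(values), *values]
metrics_table.add_row(*[str(entry) for entry in row])

rich_console.Console().print(metrics_table)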
65 changes: 60 additions & 5 deletions src/eva/core/trainers/_recorder.py
@@ -5,18 +5,41 @@
import os
import statistics
import sys
from typing import Any, Dict, List, Mapping
from typing import Any, Dict, List, Mapping, TypedDict

from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
from lightning_fabric.utilities import cloud_io
from loguru import logger
from omegaconf import OmegaConf
from rich import console as rich_console
from rich import table as rich_table
from toolz import dicttoolz

SESSION_METRICS = Mapping[str, List[float]]
"""Session metrics type-hint."""


class SESSION_STATISTICS(TypedDict):
"""Type-hint for aggregated metrics of multiple runs with mean & stdev."""

mean: float
stdev: float
values: List[float]


class STAGE_RESULTS(TypedDict):
"""Type-hint for metrics statstics for val & test stages."""

val: List[Dict[str, SESSION_STATISTICS]]
test: List[Dict[str, SESSION_STATISTICS]]


class RESULTS_DICT(TypedDict):
"""Type-hint for the final results dictionary."""

metrics: STAGE_RESULTS


class SessionRecorder:
"""Multi-run (session) summary logger."""

@@ -67,20 +90,21 @@ def update(
self._update_validation_metrics(validation_scores)
self._update_test_metrics(test_scores)

def compute(self) -> Dict[str, List[Dict[str, Any]]]:
def compute(self) -> STAGE_RESULTS:
"""Computes and returns the session statistics."""
validation_statistics = list(map(_calculate_statistics, self._validation_metrics))
test_statistics = list(map(_calculate_statistics, self._test_metrics))
return {"val": validation_statistics, "test": test_statistics}

def export(self) -> Dict[str, Any]:
def export(self) -> RESULTS_DICT:
"""Exports the results."""
statistics = self.compute()
return {"metrics": statistics}

def save(self) -> None:
"""Saves the recorded results."""
results = self.export()
_print_results(results)
_save_json(results, self.filename)
self._save_config()

@@ -125,10 +149,10 @@ def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]:
return [collections.defaultdict(list) for _ in range(n_datasets)]


def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float | List[float]]:
def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, SESSION_STATISTICS]:
"""Calculate the metric statistics of a dataset session run."""

def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[float]]:
def _calculate_metric_statistics(values: List[float]) -> SESSION_STATISTICS:
"""Calculates and returns the metric statistics."""
mean = statistics.mean(values)
stdev = statistics.stdev(values) if len(values) > 1 else 0
@@ -147,3 +171,34 @@ def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
fs.makedirs(output_dir, exist_ok=True)
with fs.open(save_as, "w") as file:
json.dump(data, file, indent=4, sort_keys=True)


def _print_results(results: RESULTS_DICT) -> None:
"""Prints the results to the console."""
for stage in ["val", "test"]:
for dataset_idx in range(len(results["metrics"][stage])):
_print_table(results["metrics"][stage][dataset_idx], stage, dataset_idx)


def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, dataset_idx: int):
"""Prints the metrics of a single dataset as a table."""
metrics_table = rich_table.Table(
title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold"
)
metrics_table.add_column("Metric", style="cyan")
metrics_table.add_column("Mean", justify="right", style="magenta")
metrics_table.add_column("Stdev", justify="right", style="magenta")

n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"])
for i in range(n_runs):
metrics_table.add_column(f"Run {i}", justify="right", style="magenta")

for metric_name, metric_dict in metrics_dict.items():
row = [metric_name, metric_dict["mean"], metric_dict["stdev"]] + [
metric_dict["values"][i] for i in range(n_runs)
]
row = [str(entry) for entry in row]
metrics_table.add_row(*row)

console = rich_console.Console()
console.print(metrics_table)
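For reference, this is the nested shape described by `RESULTS_DICT` and consumed by `_print_results`, built here by hand with invented metric names and numbers rather than from a real session:

from eva.core.trainers._recorder import _print_results

# Hypothetical session results: one validation and one test dataset,
# each with a single metric aggregated over two runs.
results = {
    "metrics": {
        "val": [{"MulticlassAccuracy": {"mean": 0.91, "stdev": 0.028, "values": [0.89, 0.93]}}],
        "test": [{"MulticlassAccuracy": {"mean": 0.88, "stdev": 0.014, "values": [0.87, 0.89]}}],
    }
}

_print_results(results)  # prints one table per stage/dataset pair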
2 changes: 1 addition & 1 deletion src/eva/core/trainers/functional.py
@@ -82,7 +82,7 @@ def fit_and_validate(
A tuple with the validation and the test metrics (if present).
"""
trainer.fit(model, datamodule=datamodule)
validation_scores = trainer.validate(datamodule=datamodule)
validation_scores = trainer.validate(datamodule=datamodule, verbose=False)
test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule)
return validation_scores, test_scores
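Passing `verbose=False` keeps Lightning from printing its own metrics table after each validation run, so the aggregated session tables above become the single end-of-session summary. A sketch of how these pieces might be wired together over multiple runs (the trainer/recorder objects and the exact call signatures are assumptions for illustration):

def run_session(trainer, model, datamodule, recorder, n_runs: int = 3) -> None:
    """Runs several fit/validate cycles and prints the aggregated summary."""
    for _ in range(n_runs):
        validation_scores, test_scores = fit_and_validate(trainer, model, datamodule)
        recorder.update(validation_scores, test_scores)
    recorder.save()  # prints the rich tables and writes the JSON results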

