From 363c2586bfda592a7ebb6b559e1772cc7aa6f43e Mon Sep 17 00:00:00 2001
From: Nicolas Kaenzig
Date: Thu, 21 Mar 2024 09:08:40 +0100
Subject: [PATCH] print results table at the end of an evaluation session

---
 src/eva/core/trainers/_recorder.py  | 65 ++++++++++++++++++++++++++---
 src/eva/core/trainers/functional.py |  2 +-
 2 files changed, 61 insertions(+), 6 deletions(-)

diff --git a/src/eva/core/trainers/_recorder.py b/src/eva/core/trainers/_recorder.py
index 64950729..783e05f8 100644
--- a/src/eva/core/trainers/_recorder.py
+++ b/src/eva/core/trainers/_recorder.py
@@ -5,18 +5,41 @@
 import os
 import statistics
 import sys
-from typing import Any, Dict, List, Mapping
+from typing import Any, Dict, List, Mapping, TypedDict
 
 from lightning.pytorch.utilities.types import _EVALUATE_OUTPUT
 from lightning_fabric.utilities import cloud_io
 from loguru import logger
 from omegaconf import OmegaConf
+from rich import console as rich_console
+from rich import table as rich_table
 from toolz import dicttoolz
 
 SESSION_METRICS = Mapping[str, List[float]]
 """Session metrics type-hint."""
 
 
+class SESSION_STATISTICS(TypedDict):
+    """Type-hint for aggregated metrics of multiple runs with mean & stdev."""
+
+    mean: float
+    stdev: float
+    values: List[float]
+
+
+class STAGE_RESULTS(TypedDict):
+    """Type-hint for metric statistics of the val & test stages."""
+
+    val: List[Dict[str, SESSION_STATISTICS]]
+    test: List[Dict[str, SESSION_STATISTICS]]
+
+
+class RESULTS_DICT(TypedDict):
+    """Type-hint for the final results dictionary."""
+
+    metrics: STAGE_RESULTS
+
+
 class SessionRecorder:
     """Multi-run (session) summary logger."""
 
@@ -67,13 +90,13 @@ def update(
         self._update_validation_metrics(validation_scores)
         self._update_test_metrics(test_scores)
 
-    def compute(self) -> Dict[str, List[Dict[str, Any]]]:
+    def compute(self) -> STAGE_RESULTS:
         """Computes and returns the session statistics."""
         validation_statistics = list(map(_calculate_statistics, self._validation_metrics))
         test_statistics = list(map(_calculate_statistics, self._test_metrics))
         return {"val": validation_statistics, "test": test_statistics}
 
-    def export(self) -> Dict[str, Any]:
+    def export(self) -> RESULTS_DICT:
         """Exports the results."""
         statistics = self.compute()
         return {"metrics": statistics}
@@ -81,6 +104,7 @@ def export(self) -> Dict[str, Any]:
     def save(self) -> None:
         """Saves the recorded results."""
         results = self.export()
+        _print_results(results)
         _save_json(results, self.filename)
         self._save_config()
 
@@ -125,10 +149,10 @@ def _init_session_metrics(n_datasets: int) -> List[SESSION_METRICS]:
     return [collections.defaultdict(list) for _ in range(n_datasets)]
 
 
-def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, float | List[float]]:
+def _calculate_statistics(session_metrics: SESSION_METRICS) -> Dict[str, SESSION_STATISTICS]:
    """Calculate the metric statistics of a dataset session run."""
 
-    def _calculate_metric_statistics(values: List[float]) -> Dict[str, float | List[float]]:
+    def _calculate_metric_statistics(values: List[float]) -> SESSION_STATISTICS:
         """Calculates and returns the metric statistics."""
         mean = statistics.mean(values)
         stdev = statistics.stdev(values) if len(values) > 1 else 0
@@ -147,3 +171,34 @@ def _save_json(data: Dict[str, Any], save_as: str = "data.json"):
     fs.makedirs(output_dir, exist_ok=True)
     with fs.open(save_as, "w") as file:
         json.dump(data, file, indent=4, sort_keys=True)
+
+
+def _print_results(results: RESULTS_DICT) -> None:
+    """Prints the results to the console."""
+    for stage in ["val", "test"]:
+        for dataset_idx in range(len(results["metrics"][stage])):
+            _print_table(results["metrics"][stage][dataset_idx], stage, dataset_idx)
+
+
+def _print_table(metrics_dict: Dict[str, SESSION_STATISTICS], stage: str, dataset_idx: int):
+    """Prints the metrics of a single dataset as a table."""
+    metrics_table = rich_table.Table(
+        title=f"\n{stage.capitalize()} Dataset {dataset_idx}", title_style="bold"
+    )
+    metrics_table.add_column("Metric", style="cyan")
+    metrics_table.add_column("Mean", justify="right", style="magenta")
+    metrics_table.add_column("Stdev", justify="right", style="magenta")
+
+    n_runs = len(metrics_dict[next(iter(metrics_dict))]["values"])
+    for i in range(n_runs):
+        metrics_table.add_column(f"Run {i}", justify="right", style="magenta")
+
+    for metric_name, metric_dict in metrics_dict.items():
+        row = [metric_name, metric_dict["mean"], metric_dict["stdev"]] + [
+            metric_dict["values"][i] for i in range(n_runs)
+        ]
+        row = [str(entry) for entry in row]
+        metrics_table.add_row(*row)
+
+    console = rich_console.Console()
+    console.print(metrics_table)
diff --git a/src/eva/core/trainers/functional.py b/src/eva/core/trainers/functional.py
index 00f81f5e..9e630524 100644
--- a/src/eva/core/trainers/functional.py
+++ b/src/eva/core/trainers/functional.py
@@ -82,7 +82,7 @@ def fit_and_validate(
         A tuple of with the validation and the test metrics (if exists).
     """
     trainer.fit(model, datamodule=datamodule)
-    validation_scores = trainer.validate(datamodule=datamodule)
+    validation_scores = trainer.validate(datamodule=datamodule, verbose=False)
     test_scores = None if datamodule.datasets.test is None else trainer.test(datamodule=datamodule)
     return validation_scores, test_scores
 
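
Usage note (illustrative, not part of the patch): a minimal sketch of the
results dictionary that SessionRecorder.export() produces and that the new
_print_results helper renders; the metric name and all values below are
made up, only the dict shape follows RESULTS_DICT.

    # Sketch only: invented metric values, shape matching RESULTS_DICT.
    from eva.core.trainers._recorder import _print_results

    results = {
        "metrics": {
            "val": [
                {"MulticlassAccuracy": {"mean": 0.91, "stdev": 0.014, "values": [0.90, 0.92]}}
            ],
            "test": [
                {"MulticlassAccuracy": {"mean": 0.89, "stdev": 0.014, "values": [0.88, 0.90]}}
            ],
        }
    }

    # Prints one rich table per stage and dataset, e.g. "Val Dataset 0",
    # with Metric / Mean / Stdev columns plus one column per run.
    _print_results(results)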