Skip to content

Commit

Permalink
support pl module test step (#344)
Browse files Browse the repository at this point in the history
* support passing losses during test step

* use test_metrics if provided

---------

Co-authored-by: Sivan Ravid <[email protected]>
  • Loading branch information
sivanravidos and Sivan Ravid authored Mar 11, 2024
1 parent 1b50c1d commit 390b1b3
Showing 1 changed file with 46 additions and 21 deletions.
67 changes: 46 additions & 21 deletions fuse/dl/lightning/pl_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,12 @@ def __init__(
List[Tuple[str, OrderedDict[str, MetricBase]]],
]
] = None,
test_metrics: Optional[OrderedDict[str, MetricBase]] = None,
test_metrics: Optional[
Union[
OrderedDict[str, MetricBase],
List[Tuple[str, OrderedDict[str, MetricBase]]],
]
] = None,
optimizers_and_lr_schs: Any = None,
callbacks: Optional[Sequence[pl.Callback]] = None,
best_epoch_source: Optional[Union[Dict, List[Dict]]] = None,
Expand All @@ -67,7 +72,9 @@ def __init__(
:param validation_metrics: ordereddict of FuseMedML style metrics - used for validation set (must be different instances of metrics (from train_metrics!)
In case of multiple validation dataloaders, validation_metrics should be list of tuples (that keeps the same dataloaders list order),
Each tuple built from validation dataloader name and corresponding metrics dict.
:param test_metrics: dict of FuseMedML style metrics - used for test set (must be different instances of metrics (from train_metrics and validation_metrics!)
:param test_metrics: ordereddict of FuseMedML style metrics - used for test set (must be different instances of metrics (from train_metrics)
In case of multiple test dataloaders, test_metrics should be list of tuples (that keeps the same dataloaders list order),
Each tuple built from test dataloader name and corresponding metrics dict.
:param optimizers_and_lr_schs: see pl.LightningModule.configure_optimizers return value for all options
:param callbacks: see pl.LightningModule.configure_callbacks return value for details
:param best_epoch_source: Create list of pl.callbacks that saves checkpoints using (pl.callbacks.ModelCheckpoint) and print per epoch summary (fuse.dl.lightning.pl_epoch_summary.ModelEpochSummary).
Expand Down Expand Up @@ -138,12 +145,14 @@ def __init__(
self._validation_metrics = (
validation_metrics if validation_metrics is not None else {}
)

if test_metrics is None:
self._test_metrics = self._validation_metrics
else:
self._test_metrics = test_metrics
# convert all use-cases to the same format that supports multiple val dataloaders: List[Tuple[str, OrderedDict[str, MetricBase]]]
if isinstance(self._validation_metrics, dict):
self._validation_metrics = [(None, self._validation_metrics)]

self._test_metrics = test_metrics if test_metrics is not None else {}
if log_unit not in [None, "optimizer_step", "epoch"]:
raise Exception(f"Error: unexpected log_unit {log_unit}")

Expand All @@ -158,11 +167,12 @@ def __init__(
self._prediction_keys = None
self._sep = tensorboard_sep

self._training_step_outputs = []

self._validation_step_outputs = {
i: [] for i, _ in enumerate(self._validation_metrics)
}
self._training_step_outputs = []
self._test_step_outputs = []
self._test_step_outputs = {i: [] for i, _ in enumerate(self._test_metrics)}

## forward
def forward(self, batch_dict: NDict) -> NDict:
Expand Down Expand Up @@ -204,17 +214,27 @@ def validation_step(
{"losses": batch_dict["losses"]}
)

def test_step(self, batch_dict: NDict, batch_idx: int) -> None:
def test_step(
self, batch_dict: NDict, batch_idx: int, dataloader_idx: int = 0
) -> None:
# add step number to batch_dict
batch_dict["global_step"] = self.global_step
# run forward function and store the outputs in batch_dict["model"]
batch_dict = self.forward(batch_dict)
# given the batch_dict and FuseMedML style losses - compute the losses, return the total loss (ignored) and save losses values in batch_dict["losses"]
_ = step_losses(self._losses, batch_dict)
if self._validation_losses is not None:
losses = self._validation_losses[dataloader_idx][1]
else:
losses = self._losses

_ = step_losses(losses, batch_dict)
# given the batch_dict and FuseMedML style metrics - collect the required values to compute the metrics on epoch_end
step_metrics(self._test_metrics, batch_dict)
step_metrics(self._test_metrics[dataloader_idx][1], batch_dict)
# aggregate losses
self._test_step_outputs.append({"losses": batch_dict["losses"]})
if losses: # if there are losses, collect the results
self._test_step_outputs[dataloader_idx].append(
{"losses": batch_dict["losses"]}
)

def predict_step(self, batch_dict: NDict, batch_idx: int) -> dict:
if self._prediction_keys is None:
Expand Down Expand Up @@ -267,20 +287,25 @@ def on_validation_epoch_end(self) -> None:
}

def on_test_epoch_end(self) -> None:
step_outputs = self._test_step_outputs
step_outputs_lst = self._test_step_outputs
# for the logs to be at each epoch, not each step
if self._log_unit == "epoch":
self.log("step", self.current_epoch, on_epoch=True, sync_dist=True)
# calc average epoch loss and log it
epoch_end_compute_and_log_losses(
self, "test", [e["losses"] for e in step_outputs], sep=self._sep
)
# evaluate and log it
epoch_end_compute_and_log_metrics(
self, "test", self._test_metrics, sep=self._sep
)
self.log("step", float(self.current_epoch), on_epoch=True, sync_dist=True)
for dataloader_idx, step_outputs in step_outputs_lst.items():
if len(self._test_metrics) == 1:
prefix = "test"
else:
prefix = f"test.{self._test_metrics[dataloader_idx][0]}"
# calc average epoch loss and log it
epoch_end_compute_and_log_losses(
self, prefix, [e["losses"] for e in step_outputs], sep=self._sep
)
# evaluate and log it
epoch_end_compute_and_log_metrics(
self, prefix, self._test_metrics[dataloader_idx][1], sep=self._sep
)
# reset state
self._test_step_outputs.clear()
self._test_step_outputs = {i: [] for i, _ in enumerate(self._test_metrics)}

# configuration
def configure_callbacks(self) -> Sequence[pl.Callback]:
Expand Down

0 comments on commit 390b1b3

Please sign in to comment.