def log_comet_cm(pl_module, confmat, phase, class_names=None):
    """Log a confusion matrix to the Comet experiment attached to the trainer.

    Used by the metric logging callback at the end of an epoch. Does nothing
    when ``get_comet_logger`` returns a falsy value for the module's trainer.

    Args:
        pl_module: LightningModule being run; provides ``trainer``,
            ``current_epoch`` and ``hparams.classification_dict``.
        confmat: confusion matrix tensor; it is moved to CPU and converted
            to nested lists before being sent.
        phase: phase name (e.g. "train") used in the file name and the title.
        class_names: labels of the matrix axes. When omitted, the values of
            ``pl_module.hparams.classification_dict`` are used.

    """
    logger = get_comet_logger(trainer=pl_module.trainer)
    if not logger:
        return

    # Honour the labels given by the caller; only fall back to the
    # hyperparameters when none were provided.
    if class_names is None:
        class_names = pl_module.hparams.classification_dict.values()

    logger.experiment.log_confusion_matrix(
        matrix=confmat.cpu().numpy().tolist(),
        labels=list(class_names),  # accepts dict views as well as lists
        file_name=f"{phase}-confusion-matrix",
        title=f"{phase} confusion matrix",  # f-string: the phase must be interpolated
        epoch=pl_module.current_epoch,
    )
class ModelMetrics(Callback):
    """Compute and log metrics for multiclass classification.

    Accuracy, Precision, Recall, F1Score are micro-averaged.
    IoU (Jaccard Index) is macro-averaged to get the mIoU.
    All metrics are also computed per class.

    Each phase ("train", "val", "test") owns its own metric objects and its
    own confusion matrix, so that states accumulated in one phase never leak
    into another. States are updated at the end of each batch from the
    ``logits`` and ``targets`` entries of the step outputs, then computed,
    logged and reset at the end of each epoch.

    The metrics live in plain dicts (not as module attributes), hence they
    are explicitly moved to the relevant device before being used.

    Be careful when manually computing/resetting metrics. See:
    https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html

    Args:
        num_classes: number of classes of the classification task.

    """

    _PHASES = ("train", "val", "test")

    def __init__(self, num_classes=7):
        super().__init__()
        self.num_classes = num_classes
        self.metrics = {phase: self._metrics_factory() for phase in self._PHASES}
        self.metrics_by_class = {
            phase: self._metrics_factory(by_class=True) for phase in self._PHASES
        }
        # One confusion matrix per phase: a single shared instance would mix
        # train, val and test predictions together.
        self.cm = {
            phase: ConfusionMatrix(task="multiclass", num_classes=self.num_classes)
            for phase in self._PHASES
        }

    def _metrics_factory(self, by_class=False):
        """Build a fresh dict of metrics, averaged or per class (``by_class=True``)."""
        average = None if by_class else "micro"
        average_iou = None if by_class else "macro"  # special case, only mean IoU is of interest

        return {
            "acc": Accuracy(task="multiclass", num_classes=self.num_classes, average=average),
            "precision": Precision(
                task="multiclass", num_classes=self.num_classes, average=average
            ),
            "recall": Recall(task="multiclass", num_classes=self.num_classes, average=average),
            "f1": F1Score(task="multiclass", num_classes=self.num_classes, average=average),
            # DEBUG: checking that this iou matches the one from model.py before removing it
            "iou": JaccardIndex(
                task="multiclass", num_classes=self.num_classes, average=average_iou
            ),
        }

    def _end_of_batch(self, phase: str, outputs):
        """Accumulate the predictions of one batch into the metrics of `phase`.

        Args:
            phase: one of "train", "val", "test".
            outputs: step outputs, with a "logits" tensor (classes along dim 1)
                and a "targets" tensor.

        """
        targets = outputs["targets"]
        preds = torch.argmax(outputs["logits"].detach(), dim=1)
        phase_metrics = (
            *self.metrics[phase].values(),
            *self.metrics_by_class[phase].values(),
            self.cm[phase],
        )
        for metric in phase_metrics:
            # Only the accumulated state is needed: update() avoids computing
            # a batch-level value that would be thrown away.
            metric.to(preds.device).update(preds, targets)

    def _end_of_epoch(self, phase: str, pl_module):
        """Compute, log and reset every metric of `phase`.

        Args:
            phase: one of "train", "val", "test".
            pl_module: LightningModule being run; used for logging and for its
                ``device`` and ``hparams.classification_dict``.

        """
        device = pl_module.device

        for metric_name, metric in self.metrics[phase].items():
            metric_name_for_log = f"{phase}/{metric_name}"
            value = metric.to(device).compute()
            # Logging goes through the LightningModule.
            pl_module.log(
                metric_name_for_log,
                value,
                on_epoch=True,
                on_step=False,
                metric_attribute=metric_name_for_log,
            )
            metric.reset()  # always reset state when using compute().

        class_names = list(pl_module.hparams.classification_dict.values())
        for metric_name, metric in self.metrics_by_class[phase].items():
            values = metric.to(device).compute()
            for value, class_name in zip(values, class_names):
                metric_name_for_log = f"{phase}/{metric_name}/{class_name}"
                pl_module.log(
                    metric_name_for_log,
                    value,
                    on_step=False,
                    on_epoch=True,
                    metric_attribute=metric_name_for_log,
                )
            metric.reset()  # always reset state when using compute().

        confusion_matrix = self.cm[phase]
        log_comet_cm(pl_module, confusion_matrix.to(device).compute(), phase, class_names)
        confusion_matrix.reset()  # otherwise counts pile up across epochs

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._end_of_batch("train", outputs)

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        self._end_of_batch("val", outputs)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0):
        self._end_of_batch("test", outputs)

    def on_train_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("train", pl_module)

    def on_validation_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("val", pl_module)

    def on_val_epoch_end(self, trainer, pl_module):
        """Former name of `on_validation_epoch_end`, kept for direct callers.

        Lightning does not call this name; the actual hook is
        `on_validation_epoch_end`.

        """
        self.on_validation_epoch_end(trainer, pl_module)

    def on_test_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("test", pl_module)
true_positives_and_false_positives = confmat.sum(dim=1) - true_positives = confmat.diag() - union = ( - true_positives_and_false_negatives + true_positives_and_false_positives - true_positives - ) - iou = EPSILON + true_positives / (union + EPSILON) - return iou diff --git a/myria3d/models/model.py b/myria3d/models/model.py index 3dbd5fc5..f1d842d7 100755 --- a/myria3d/models/model.py +++ b/myria3d/models/model.py @@ -3,10 +3,7 @@ from torch import nn from torch_geometric.data import Batch from torch_geometric.nn import knn_interpolate -from torchmetrics.classification import MulticlassJaccardIndex -from myria3d.callbacks.comet_callbacks import log_comet_cm -from myria3d.metrics.iou import iou from myria3d.models.modules.pyg_randla_net import PyGRandLANet from myria3d.utils import utils @@ -33,14 +30,12 @@ def get_neural_net_class(class_name: str) -> nn.Module: class Model(LightningModule): - """This LightningModule implements the logic for model trainin, validation, tests, and prediction. + """Model training, validation, test and prediction of point cloud semantic segmentation. - It is fully initialized by named parameters for maximal flexibility with hydra configs. + During training and validation, metrics are calculated based on subsampled points only. + At test time, metrics are calculated considering all the points. - During training and validation, IoU is calculed based on sumbsampled points only, and is therefore - an approximation. - At test time, IoU is calculated considering all the points. To keep this module light, a callback - takes care of the interpolation of predictions between all points. + To keep this module light, a callback takes care of metric computations. Read the Pytorch Lightning docs: @@ -51,7 +46,7 @@ class Model(LightningModule): def __init__(self, **kwargs): """Initialization method of the Model lightning module. 
- Everything needed to train/test/predict with a neural architecture, including + Everything needed to train/evaluate/test/predict with a neural architecture, including the architecture class name and its hyperparameter. See config files for a list of kwargs. @@ -69,22 +64,6 @@ def __init__(self, **kwargs): self.softmax = nn.Softmax(dim=1) self.criterion = kwargs.get("criterion") - def on_fit_start(self) -> None: - self.criterion = self.criterion.to(self.device) - self.train_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device) - self.val_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device) - - def on_test_start(self) -> None: - self.test_iou = MulticlassJaccardIndex(self.hparams.num_classes).to(self.device) - - def log_all_class_ious(self, confmat, phase: str): - ious = iou(confmat) - for class_iou, class_name in zip(ious, self.hparams.classification_dict.values()): - metric_name = f"{phase}/iou_CLASS_{class_name}" - self.log( - metric_name, class_iou, on_step=False, on_epoch=True, metric_attribute=metric_name - ) - def forward(self, batch: Batch) -> torch.Tensor: """Forward pass of neural network. @@ -126,8 +105,6 @@ def forward(self, batch: Batch) -> torch.Tensor: def training_step(self, batch: Batch, batch_idx: int) -> dict: """Training step. - Makes a model pass. Then, computes loss and predicted class of subsampled points to log loss and IoU. - Args: batch (torch_geometric.data.Batch): Batch of data including x (features), pos (xyz positions), and y (targets, optionnal) in (B*N,C) format. 
@@ -140,25 +117,11 @@ def training_step(self, batch: Batch, batch_idx: int) -> dict: self.criterion = self.criterion.to(logits.device) loss = self.criterion(logits, targets) self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False) - - with torch.no_grad(): - preds = torch.argmax(logits.detach(), dim=1) - self.train_iou(preds, targets) - return {"loss": loss, "logits": logits, "targets": targets} - def on_train_epoch_end(self) -> None: - iou_epoch = self.train_iou.to(self.device).compute() - self.log("train/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True) - self.log_all_class_ious(self.train_iou.confmat, "train") - log_comet_cm(self, self.train_iou.confmat, "train") - self.train_iou.reset() - def validation_step(self, batch: Batch, batch_idx: int) -> dict: """Validation step. - Makes a model pass. Then, computes loss and predicted class of subsampled points to log loss and IoU. - Args: batch (torch_geometric.data.Batch): Batch of data including x (features), pos (xyz positions), and y (targets, optionnal) in (B*N,C) format. @@ -172,26 +135,8 @@ def validation_step(self, batch: Batch, batch_idx: int) -> dict: self.criterion = self.criterion.to(logits.device) loss = self.criterion(logits, targets) self.log("val/loss", loss, on_step=True, on_epoch=True) - - preds = torch.argmax(logits.detach(), dim=1) - self.val_iou = self.val_iou.to(preds.device) - self.val_iou(preds, targets) - return {"loss": loss, "logits": logits, "targets": targets} - def on_validation_epoch_end(self) -> None: - """At the end of a validation epoch, compute the IoU. - - Args: - outputs : output of validation_step - - """ - iou_epoch = self.val_iou.to(self.device).compute() - self.log("val/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True) - self.log_all_class_ious(self.val_iou.confmat, "val") - log_comet_cm(self, self.val_iou.confmat, "val") - self.val_iou.reset() - def test_step(self, batch: Batch, batch_idx: int): """Test step. 
@@ -207,26 +152,8 @@ def test_step(self, batch: Batch, batch_idx: int): self.criterion = self.criterion.to(logits.device) loss = self.criterion(logits, targets) self.log("test/loss", loss, on_step=False, on_epoch=True) - - preds = torch.argmax(logits, dim=1) - self.test_iou = self.test_iou.to(preds.device) - self.test_iou(preds, targets) - return {"loss": loss, "logits": logits, "targets": targets} - def on_test_epoch_end(self) -> None: - """At the end of a validation epoch, compute the IoU. - - Args: - outputs : output of test - - """ - iou_epoch = self.test_iou.to(self.device).compute() - self.log("test/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True) - self.log_all_class_ious(self.test_iou.confmat, "test") - log_comet_cm(self, self.test_iou.confmat, "test") - self.test_iou.reset() - def predict_step(self, batch: Batch) -> dict: """Prediction step.