Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Log confusion matrices after each epoch #111

Merged
merged 3 commits into from
Feb 8, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# CHANGELOG

### 3.8.0
- dev: log confusion matrices to Comet after each epoch.
- fix: do not mix the two way to log IoUs to avoid known lightning [Common Pitfalls](https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html#common-pitfalls).

### 3.7.1
- fix: edge case when saving predictions under Classification channel, without saving entropy.

Expand Down
15 changes: 14 additions & 1 deletion myria3d/callbacks/comet_callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ def get_comet_logger(trainer: Trainer) -> Optional[CometLogger]:
return logger

warnings.warn(
"You are using comet related callback, but CometLogger was not found for some reason...",
"You are using comet related functions, but trainer has no CometLogger among its loggers.",
UserWarning,
)
return None
Expand Down Expand Up @@ -71,3 +71,16 @@ def setup(self, trainer, pl_module, stage):
log_path = os.getcwd()
log.info(f"----------------\n LOGS DIR is {log_path}\n ----------------")
logger.experiment.log_parameter("experiment_logs_dirpath", log_path)


def log_comet_cm(lightning_module, confmat, phase):
logger = get_comet_logger(trainer=lightning_module)
if logger:
labels = list(lightning_module.hparams.classification_dict.values())
logger.experiment.log_confusion_matrix(
matrix=confmat.cpu().numpy().tolist(),
labels=labels,
file_name=f"{phase}-confusion-matrix",
title="{phase} confusion matrix",
epoch=lightning_module.current_epoch,
)
20 changes: 13 additions & 7 deletions myria3d/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from torch_geometric.data import Batch
from torch_geometric.nn import knn_interpolate
from torchmetrics.classification import MulticlassJaccardIndex
from myria3d.callbacks.comet_callbacks import log_comet_cm

from myria3d.metrics.iou import iou
from myria3d.models.modules.pyg_randla_net import PyGRandLANet
Expand Down Expand Up @@ -143,12 +144,14 @@ def training_step(self, batch: Batch, batch_idx: int) -> dict:
with torch.no_grad():
preds = torch.argmax(logits.detach(), dim=1)
self.train_iou(preds, targets)
self.log("train/iou", self.train_iou, on_step=True, on_epoch=True, prog_bar=True)

return {"loss": loss, "logits": logits, "targets": targets}

def on_train_epoch_end(self) -> None:
self.train_iou.compute()
iou_epoch = self.train_iou.compute()
self.log("train/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.train_iou.confmat, "train")
log_comet_cm(self, self.train_iou.confmat, "train")
self.train_iou.reset()

def validation_step(self, batch: Batch, batch_idx: int) -> dict:
Expand All @@ -173,7 +176,7 @@ def validation_step(self, batch: Batch, batch_idx: int) -> dict:
preds = torch.argmax(logits.detach(), dim=1)
self.val_iou = self.val_iou.to(preds.device)
self.val_iou(preds, targets)
self.log("val/iou", self.val_iou, on_step=True, on_epoch=True, prog_bar=True)

return {"loss": loss, "logits": logits, "targets": targets}

def on_validation_epoch_end(self) -> None:
Expand All @@ -183,8 +186,10 @@ def on_validation_epoch_end(self) -> None:
outputs : output of validation_step

"""
self.val_iou.compute()
iou_epoch = self.val_iou.compute()
self.log("val/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.val_iou.confmat, "val")
log_comet_cm(self, self.val_iou.confmat, "val")
self.val_iou.reset()

def test_step(self, batch: Batch, batch_idx: int):
Expand All @@ -201,12 +206,11 @@ def test_step(self, batch: Batch, batch_idx: int):
targets, logits = self.forward(batch)
self.criterion = self.criterion.to(logits.device)
loss = self.criterion(logits, targets)
self.log("test/loss", loss, on_step=True, on_epoch=True)
self.log("test/loss", loss, on_step=False, on_epoch=True)

preds = torch.argmax(logits, dim=1)
self.test_iou = self.test_iou.to(preds.device)
self.test_iou(preds, targets)
self.log("test/iou", self.test_iou, on_step=False, on_epoch=True, prog_bar=True)

return {"loss": loss, "logits": logits, "targets": targets}

Expand All @@ -217,8 +221,10 @@ def on_test_epoch_end(self) -> None:
outputs : output of test

"""
self.test_iou.compute()
iou_epoch = self.test_iou.compute()
self.log("test/iou", iou_epoch, on_step=False, on_epoch=True, prog_bar=True)
self.log_all_class_ious(self.test_iou.confmat, "test")
log_comet_cm(self, self.test_iou.confmat, "test")
self.test_iou.reset()

def predict_step(self, batch: Batch) -> dict:
Expand Down
2 changes: 1 addition & 1 deletion package_metadata.yaml
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__: "3.7.1"
__version__: "3.8.0"
__name__: "myria3d"
__url__: "https://github.com/IGNF/myria3d"
__description__: "Deep Learning for the Semantic Segmentation of Aerial Lidar Point Clouds"
Expand Down