Merge branch 'main' into edit-contact-info-cg
Showing 8 changed files with 236 additions and 104 deletions.
.github/workflows/predict-for-lidar-prod-optimization.yml (new file: 110 additions, 0 deletions)
# Workflow name
name: "Prediction on lidar-prod optimization dataset"

on:
  # Run workflow on user request
  workflow_dispatch:
    inputs:
      user:
        description: |
          Username:
          Used to generate a standard path for the outputs in the IA folder
          of the store (projet-LHD/IA/MYRIA3D-SHARED-WORKSPACE/$USER/$SAMPLING_NAME/)
        required: true
      sampling_name:
        description: |
          Sampling name:
          Name of the dataset the model was trained on.
          Used to generate a standard path for the outputs in the IA folder
          of the store (projet-LHD/IA/MYRIA3D-SHARED-WORKSPACE/$USER/$SAMPLING_NAME/)
          E.g. YYYYMMDD_MonBeauDataset
        required: true
      model_id:
        description: |
          Model identifier:
          Name of the checkpoint file to use for predictions (without the .ckpt extension!)
          ($MODEL_ID.ckpt must exist in projet-LHD/IA/MYRIA3D-SHARED-WORKSPACE/$USER/$SAMPLING_NAME/)
          Also used to generate the output folder
          (projet-LHD/IA/LIDAR-PROD-OPTIMIZATION/$SAMPLING_NAME/$MODEL_ID)
          Example: YYYYMMDD_MonBeauSampling_epochXXX_Myria3Dx.y.z
        required: true
      predict_config_name:
        description: |
          Name of the myria3d config file (.yaml) to use for prediction
          (must exist in projet-LHD/IA/MYRIA3D-SHARED-WORKSPACE/$USER/$SAMPLING_NAME/)
          Example: YYYYMMDD_MonBeauSampling_epochXXX_Myria3Dx.y.z_predict_config_Vx.y.z.yaml
        required: true

jobs:
  predict-validation-dataset:
    runs-on: self-hosted
    env:
      OUTPUT_DIR: /var/data/LIDAR-PROD-OPTIMIZATION/${{ github.event.inputs.sampling_name }}/${{ github.event.inputs.model_id }}/
      DATA: /var/data/LIDAR-PROD-OPTIMIZATION/20221018_lidar-prod-optimization-on-151-proto/Comparison/
      CONFIG_DIR: /var/data/MYRIA3D-SHARED-WORKSPACE/${{ github.event.inputs.user }}/${{ github.event.inputs.sampling_name }}/
      BATCH_SIZE: 25

    steps:
      - name: Log configuration
        run: |
          echo "Run prediction on lidar-prod optimization datasets (val and test)"
          echo "Sampling name: ${{ github.event.inputs.sampling_name }}"
          echo "User name: ${{ github.event.inputs.user }}"
          echo "Checkpoint name: ${{ github.event.inputs.model_id }}"
          echo "Prediction config name: ${{ github.event.inputs.predict_config_name }}"
          echo "Output_dir: ${{env.OUTPUT_DIR}}"
          echo "Data: ${{env.DATA}}"
          echo "Config files dir: ${{env.CONFIG_DIR}}"

      - name: Checkout branch
        uses: actions/checkout@v4

      # get version number, to retrieve the docker image corresponding to the current version
      - name: Get version number
        run: |
          echo "VERSION=$(docker run myria3d python -m myria3d._version)" >> $GITHUB_ENV

      - name: pull docker image tagged with current version
        run: |
          docker login ${{ secrets.DOCKER_REGISTRY }} --username svc_lidarhd --password ${{ secrets.PASSWORD_SVC_LIDARHD }}
          docker pull ${{ secrets.DOCKER_REGISTRY }}/lidar_hd/myria3d:${{ env.VERSION }}

      - name: Run prediction on validation dataset
        run: >
          docker run --network host
          --shm-size='28g'
          -v ${{env.OUTPUT_DIR}}:/output_dir
          -v ${{env.DATA}}:/data
          -v ${{env.CONFIG_DIR}}:/config_dir
          ${{ secrets.DOCKER_REGISTRY }}/lidar_hd/myria3d:${{ env.VERSION }}
          python run.py
          --config-path /config_dir
          --config-name ${{ github.event.inputs.predict_config_name }}
          task.task_name=predict
          predict.src_las=/data/val/*.laz
          predict.ckpt_path=/config_dir/${{ github.event.inputs.model_id }}.ckpt
          predict.output_dir=/output_dir/preds-valset/
          predict.interpolator.probas_to_save=[building]
          predict.gpus=0
          datamodule.batch_size=${{env.BATCH_SIZE}}
          datamodule.tile_width=1000

      - name: Run prediction on test dataset
        run: >
          docker run --network host
          --shm-size='28g'
          -v ${{env.OUTPUT_DIR}}:/output_dir
          -v ${{env.DATA}}:/data
          -v ${{env.CONFIG_DIR}}:/config_dir
          ${{ secrets.DOCKER_REGISTRY }}/lidar_hd/myria3d:${{ env.VERSION }}
          python run.py
          --config-path /config_dir
          --config-name ${{ github.event.inputs.predict_config_name }}
          task.task_name=predict
          predict.src_las=/data/test/*.laz
          predict.ckpt_path=/config_dir/${{ github.event.inputs.model_id }}.ckpt
          predict.output_dir=/output_dir/preds-testset/
          predict.interpolator.probas_to_save=[building]
          predict.gpus=0
          datamodule.batch_size=${{env.BATCH_SIZE}}
          datamodule.tile_width=1000
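
The workflow above only runs on manual request (workflow_dispatch). As a rough illustration of how such a run could be started programmatically, the Python sketch below calls the GitHub REST API workflow-dispatch endpoint; the repository slug, token variable, branch name, and all input values are hypothetical placeholders, not taken from this commit.

    # Minimal sketch: trigger the workflow via the GitHub REST API (workflow_dispatch).
    # OWNER_REPO, the token, the ref, and all input values are hypothetical placeholders.
    import os

    import requests

    OWNER_REPO = "my-org/myria3d"  # assumption: replace with the actual repository slug
    WORKFLOW_FILE = "predict-for-lidar-prod-optimization.yml"
    TOKEN = os.environ["GITHUB_TOKEN"]  # token with permission to dispatch workflows

    response = requests.post(
        f"https://api.github.com/repos/{OWNER_REPO}/actions/workflows/{WORKFLOW_FILE}/dispatches",
        headers={
            "Authorization": f"Bearer {TOKEN}",
            "Accept": "application/vnd.github+json",
        },
        json={
            "ref": "main",  # branch that contains the workflow file
            "inputs": {
                "user": "jdoe",
                "sampling_name": "20230101_MonBeauDataset",
                "model_id": "20230101_MonBeauSampling_epoch100_Myria3D3.4.0",
                "predict_config_name": "20230101_MonBeauSampling_epoch100_Myria3D3.4.0_predict_config_V3.4.0.yaml",
            },
        },
        timeout=30,
    )
    response.raise_for_status()  # the API returns 204 No Content on success

The same dispatch can of course also be made from the Actions tab, which is how the inputs described above are normally filled in.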
New file (105 additions, 0 deletions):
from pytorch_lightning import Callback
import torch
from torchmetrics import Accuracy, F1Score, JaccardIndex, Precision, Recall, ConfusionMatrix

from myria3d.callbacks.comet_callbacks import log_comet_cm


class ModelMetrics(Callback):
    """Compute metrics for multiclass classification.

    Accuracy, Precision, Recall, and F1Score are micro-averaged.
    IoU (Jaccard Index) is macro-averaged to get the mIoU.
    All metrics are also computed per class.

    Be careful when manually computing/resetting metrics. See:
    https://lightning.ai/docs/torchmetrics/stable/pages/lightning.html
    """

    def __init__(self, num_classes=7):
        self.num_classes = num_classes
        self.metrics = {
            "train": self._metrics_factory(),
            "val": self._metrics_factory(),
            "test": self._metrics_factory(),
        }
        self.metrics_by_class = {
            "train": self._metrics_factory(by_class=True),
            "val": self._metrics_factory(by_class=True),
            "test": self._metrics_factory(by_class=True),
        }
        self.cm = ConfusionMatrix(task="multiclass", num_classes=self.num_classes)

    def _metrics_factory(self, by_class=False):
        average = None if by_class else "micro"
        average_iou = None if by_class else "macro"  # special case, only mean IoU is of interest

        return {
            "acc": Accuracy(task="multiclass", num_classes=self.num_classes, average=average),
            "precision": Precision(
                task="multiclass", num_classes=self.num_classes, average=average
            ),
            "recall": Recall(task="multiclass", num_classes=self.num_classes, average=average),
            "f1": F1Score(task="multiclass", num_classes=self.num_classes, average=average),
            # DEBUG: checking that this iou matches the one from model.py before removing it
            "iou": JaccardIndex(
                task="multiclass", num_classes=self.num_classes, average=average_iou
            ),
        }

    def _end_of_batch(self, phase: str, outputs):
        targets = outputs["targets"]
        preds = torch.argmax(outputs["logits"].detach(), dim=1)
        for m in self.metrics[phase].values():
            m.to(preds.device)(preds, targets)
        for m in self.metrics_by_class[phase].values():
            m.to(preds.device)(preds, targets)
        self.cm.to(preds.device)(preds, targets)
    def _end_of_epoch(self, phase: str, pl_module):
        for metric_name, metric in self.metrics[phase].items():
            metric_name_for_log = f"{phase}/{metric_name}"
            value = metric.to(pl_module.device).compute()
            # Callbacks have no self.log; log through the LightningModule instead.
            pl_module.log(
                metric_name_for_log,
                value,
                on_epoch=True,
                on_step=False,
                metric_attribute=metric_name_for_log,
            )
            metric.reset()  # always reset state when using compute().

        class_names = pl_module.hparams.classification_dict.values()
        for metric_name, metric in self.metrics_by_class[phase].items():
            values = metric.to(pl_module.device).compute()
            for value, class_name in zip(values, class_names):
                metric_name_for_log = f"{phase}/{metric_name}/{class_name}"
                pl_module.log(
                    metric_name_for_log,
                    value,
                    on_step=False,
                    on_epoch=True,
                    metric_attribute=metric_name_for_log,
                )
            metric.reset()  # always reset state when using compute().

        log_comet_cm(pl_module, self.cm.confmat, phase, class_names)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._end_of_batch("train", outputs)

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._end_of_batch("val", outputs)

    def on_test_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        self._end_of_batch("test", outputs)

    def on_train_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("train", pl_module)

    # Lightning's hook is named on_validation_epoch_end (on_val_epoch_end would never be called).
    def on_validation_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("val", pl_module)

    def on_test_epoch_end(self, trainer, pl_module):
        self._end_of_epoch("test", pl_module)