Skip to content

Commit

Permalink
Merge pull request #15 from axondeepseg/feature/inference_and_evaluat…
Browse files Browse the repository at this point in the history
…ion/individual_folds

Enhanced Inference and Evaluation Capabilities for nnU-Net Model
  • Loading branch information
ArthurBoschet authored Feb 22, 2024
2 parents fd59988 + 94abee5 commit 3bcc401
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 19 deletions.
27 changes: 19 additions & 8 deletions nnunet_scripts/inference_and_evaluation.sh
Original file line number Diff line number Diff line change
Expand Up @@ -4,26 +4,37 @@
NNUNET_PATH=$1
DATASET_NAME="Dataset444_AGG"

if [ ! -d "${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/scores" ]; then
echo "Making directory ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/scores"
mkdir "${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/scores"
fi

shift
for dataset in "$@"; do
echo -e "\n\nRun inference with $NNUNET_PATH/$dataset model on $DATASET_NAME dataset\n\n"
dataset_name=${dataset#Dataset???_}

python nnunet_scripts/run_inference.py --path-dataset ${NNUNET_PATH}/nnUNet_raw/${DATASET_NAME}/imagesTs \
--path-out ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/${dataset_name}_model_best_checkpoints_inference \
--path-out ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/ensemble_models/${dataset_name}_model_best_checkpoints_inference \
--path-model ~/data/nnunet_all/nnUNet_results/${dataset}/nnUNetTrainer__nnUNetPlans__2d/ \
--use-gpu \
--use-best-checkpoint

python nnunet_scripts/run_evaluation.py --pred_path ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/${dataset_name}_model_best_checkpoints_inference \
python nnunet_scripts/run_evaluation.py --pred_path ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/ensemble_models/${dataset_name}_model_best_checkpoints_inference \
--mapping_path ${NNUNET_PATH}/nnUNet_raw/${DATASET_NAME}/fname_mapping.json \
--gt_path ${NNUNET_PATH}/test_labels \
--output_fname ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/scores/${dataset_name}_model_best_checkpoints_scores \
--output_fname ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/ensemble_models/scores/${dataset_name}_model_best_checkpoints_scores \
--pred_suffix _0000

# Run inference and evaluation for each fold
for fold in {0..4}; do
echo -e "\nProcessing fold $fold\n"
python nnunet_scripts/run_inference.py --path-dataset ${NNUNET_PATH}/nnUNet_raw/${DATASET_NAME}/imagesTs \
--path-out ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/fold_{$fold}/${dataset_name}_model_best_checkpoints_inference \
--path-model ~/data/nnunet_all/nnUNet_results/${dataset}/nnUNetTrainer__nnUNetPlans__2d/ \
--folds $fold \
--use-gpu \
--use-best-checkpoint

python nnunet_scripts/run_evaluation.py --pred_path ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/fold_{$fold}/${dataset_name}_model_best_checkpoints_inference \
--mapping_path ${NNUNET_PATH}/nnUNet_raw/${DATASET_NAME}/fname_mapping.json \
--gt_path ${NNUNET_PATH}/test_labels \
--output_fname ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/fold_{$fold}/scores/${dataset_name}_model_best_checkpoints_scores \
--pred_suffix _0000
done
done
9 changes: 7 additions & 2 deletions nnunet_scripts/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,14 +7,17 @@

import argparse
import json
import os
import warnings
from pathlib import Path

import cv2
import numpy as np
import torch
import pandas as pd
from pathlib import Path
import torch
from monai.metrics import DiceMetric, MeanIoU


def compute_metrics(pred, gt, metric):
"""
Compute the given metric for a single image
Expand Down Expand Up @@ -109,6 +112,8 @@ def main():

# Export the DataFrame to a CSV file
output_fname = args.output_fname + '.csv'
output_dir = Path(output_fname).parent
os.makedirs(output_dir, exist_ok=True)
df.to_csv(output_fname, index=False)
print(f'Evaluation results saved to {output_fname}.')

Expand Down
20 changes: 11 additions & 9 deletions nnunet_scripts/run_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,13 @@
Author: Naga Karthik
"""

import os
import argparse
import torch
from pathlib import Path
from batchgenerators.utilities.file_and_folder_operations import join
import os
import time
from pathlib import Path

import torch
from batchgenerators.utilities.file_and_folder_operations import join
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor

if 'nnUNet_raw' not in os.environ:
Expand Down Expand Up @@ -38,11 +38,10 @@ def get_parser() -> argparse.ArgumentParser:
parser.add_argument('--path-model', required=True,
help='Path to the model directory. This folder should contain individual folders '
'like fold_0, fold_1, etc.',)
parser.add_argument('--folds', nargs='+', type=int, default=None,
help='List of folds to use for inference. If not specified, all available folds. Default: None')
parser.add_argument('--use-gpu', action='store_true', default=False,
help='Use GPU for inference. Default: False')
parser.add_argument('--use-mirroring', action='store_true', default=False,
help='Use mirroring (test-time) augmentation for prediction. '
'NOTE: Inference takes a long time when this is enabled. Default: False')
parser.add_argument('--use-best-checkpoint', action='store_true', default=False,
help='Use the best checkpoint (instead of the final checkpoint) for prediction. '
'NOTE: nnUNet by default uses the final checkpoint. Default: False')
Expand Down Expand Up @@ -136,15 +135,18 @@ def main():
path_pred = os.path.join(args.path_out, add_suffix(fname, '_pred'))
path_out.append(path_pred)

folds_avail = [int(f.split('_')[-1]) for f in os.listdir(args.path_model) if f.startswith('fold_')]
if args.folds is not None:
folds_avail = args.folds
else:
folds_avail = [int(f.split('_')[-1]) for f in os.listdir(args.path_model) if f.startswith('fold_')]

print('Starting inference...')
start = time.time()
predictor = nnUNetPredictor(
tile_step_size=0.5,
use_gaussian=True,
use_mirroring=True,
perform_everything_on_gpu=True if args.use_gpu else False,
perform_everything_on_gpu=args.use_gpu,
device=torch.device('cuda') if args.use_gpu else torch.device('cpu'),
verbose=False,
verbose_preprocessing=False,
Expand Down

0 comments on commit 3bcc401

Please sign in to comment.