Skip to content

Commit

Permalink
Merge pull request #18 from axondeepseg/feature/add_ensemble_experiments
Browse files Browse the repository at this point in the history
 Ensembling Experiment Support Added
  • Loading branch information
ArthurBoschet authored Apr 5, 2024
2 parents b35c85b + 5b4d57c commit 95c5936
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 8 deletions.
10 changes: 8 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -81,13 +81,19 @@ Note: `<FORMATTED_DATASET_ID>` should be a three-digit number where 1 would beco
To replicate the inference experiments, execute the following script:

```bash
source ./nnunet_scripts/inference_and_evaluation.sh ${RESULTS_DIR}/nnUNet_results <DATASET_1> <DATASET_2> <DATASET_3> ... <DATASET_N>
source ./nnunet_scripts/inference_and_evaluation.sh ${NNUNET_DIR} <DATASET_1> <DATASET_2> <DATASET_3> ... <DATASET_N>
```

For instance, to run the script with specific datasets, use the command below:

```bash
source ./nnunet_scripts/inference_and_evaluation.sh ${RESULTS_DIR}/nnUNet_results Dataset002_SEM Dataset003_TEM Dataset004_BF_RAT Dataset005_wakehealth Dataset006_BF_VCU Dataset444_AGG
source ./nnunet_scripts/inference_and_evaluation.sh ${NNUNET_DIR} Dataset002_SEM Dataset003_TEM Dataset004_BF_RAT Dataset005_wakehealth Dataset006_BF_VCU Dataset444_AGG
```
In addition to the individual inference and evaluation scripts, there is an "ensemble_inference_and_evaluation.sh" script available. This script performs ensemble inferences using all the models listed and then evaluates the ensemble model. The arguments for this script are similar to the ones mentioned above, except `<DATASET_K>` represents all the models being ensembled.

To use the ensemble script, execute the following command:
```bash
source ./nnunet_scripts/ensemble_inference_and_evaluation.sh ${NNUNET_DIR} <DATASET_1> <DATASET_2> <DATASET_3> ... <DATASET_N>
```

To replicate out of distribution experiments (OOD), you can use the following script:
Expand Down
36 changes: 36 additions & 0 deletions nnunet_scripts/ensemble_inference_and_evaluation.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
# The first argument should be the path to the nnunet folder.
# The remaining arguments should be the names of the datasets to process.
NNUNET_PATH=$1
DATASET_NAME="Dataset444_AGG"

shift
for dataset in "$@"; do
echo -e "\n\nRun inference with ${NNUNET_PATH}/nnUNet_results/${dataset} model on $DATASET_NAME dataset\n\n"
dataset_name=${dataset#Dataset???_}

# Run inference and save the probabilities
python nnunet_scripts/run_inference.py --path-dataset ${NNUNET_PATH}/nnUNet_raw/${DATASET_NAME}/imagesTs \
--path-out ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/ensemble_models/${dataset_name}_model_best_checkpoints_inference_with_probabilities \
--path-model ${NNUNET_PATH}/nnUNet_results/${dataset}/nnUNetTrainer__nnUNetPlans__2d/ \
--use-gpu \
--use-best-checkpoint \
--save-probabilities
done

folders=()
ensemble_name="ensemble"
for dataset in "$@"; do
dataset_name=${dataset#Dataset???_}
folders+=("${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/ensemble_models/${dataset_name}_model_best_checkpoints_inference_with_probabilities")
ensemble_name+="_${dataset_name}"
done

# Ensemble the inference results with nnUNetv2_ensemble
nnUNetv2_ensemble -i ${folders[@]} -o ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/ensemble_models/${ensemble_name}_model_best_checkpoints_inference

# Evaluate the ensemble model
python nnunet_scripts/run_evaluation.py --pred_path ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/ensemble_models/${ensemble_name}_model_best_checkpoints_inference \
--mapping_path ${NNUNET_PATH}/nnUNet_raw/${DATASET_NAME}/fname_mapping.json \
--gt_path ${NNUNET_PATH}/test_labels \
--output_fname ${NNUNET_PATH}/nnUNet_results/${DATASET_NAME}/ensemble_models/scores/${ensemble_name}_model_best_checkpoints_scores \
--pred_suffix _0000
20 changes: 14 additions & 6 deletions nnunet_scripts/run_evaluation.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,24 +52,32 @@ def get_original_filename(gt, reverted_mapping):

def print_metric(value, gt, metric, label, reverted_mapping):
# modify gt name to add _0000 suffix before file extension
gt_name = gt.name.split('.')[0] + '_0000.' + gt.name.split('.')[1]
original_fname = get_original_filename(gt, reverted_mapping)
if reverted_mapping:
original_fname = get_original_filename(gt, reverted_mapping)
else:
original_fname = None
metric_name = metric.__class__.__name__
print(f'{metric_name} for {label} in {gt.name} (aka {original_fname}): {value}')

def main():
parser = argparse.ArgumentParser(description='Run evaluation on a generalist model')
parser.add_argument('-p', '--pred_path', type=str, help='Path to the predictions folder (axonmyelin preds)')
parser.add_argument('-m', '--mapping_path', type=str, help='Path to the filename mapping JSON file')
parser.add_argument('-m', '--mapping_path', type=str, default=None, help='Path to the filename mapping JSON file')
parser.add_argument('-g', '--gt_path', type=str, help='Path to the GT folder (axonmyelin masks)')
parser.add_argument('-o', '--output_fname', type=str, help='Filename for evaluation results')
parser.add_argument('-s', '--pred_suffix', type=str, default="", help='Suffix in the prediction files (e.g. _0000)')
args = parser.parse_args()

pred_path = Path(args.pred_path)
gt_path = Path(args.gt_path)
mapping = json.load(open(args.mapping_path))
reverted_mapping = {v: k for k, v in mapping['images_ts'].items()}

# Load the filename mapping
if args.mapping_path:
mapping = json.load(open(args.mapping_path))
reverted_mapping = {v: k for k, v in mapping['images_ts'].items()}
else:
reverted_mapping = None

# there might be more test imgs than GTs; evaluation on labelled data only
gts = [f for f in gt_path.glob('*.png')]

Expand Down Expand Up @@ -100,7 +108,7 @@ def main():
# compute the metrics
for label, pred_mask, gt_mask in classwise_pairs:
row = {
'original_fname': get_original_filename(gt, reverted_mapping),
'original_fname': get_original_filename(gt, reverted_mapping) if reverted_mapping else pred.name,
'pred_fname': pred.name,
'label': label
}
Expand Down

0 comments on commit 95c5936

Please sign in to comment.