diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py index 7903a06f..20fc2a1b 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py @@ -184,21 +184,21 @@ def validate(self) -> None: return if self._classes: - first_and_last_labels = (self._classes[0], self._classes[-1]) + last_label = self._classes[-1] n_classes = len(self._classes) elif self._class_mappings: classes = sorted(set(self._class_mappings.values())) - first_and_last_labels = (classes[0], classes[-1]) + last_label = classes[-1] n_classes = len(classes) else: - first_and_last_labels = ("adrenal_gland_left", "vertebrae_T9") + last_label = "vertebrae_T9" n_classes = 117 _validators.check_dataset_integrity( self, length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0), - n_classes=n_classes+1, - first_and_last_labels=first_and_last_labels, + n_classes=n_classes + 1, + first_and_last_labels=("background", last_label), ) @override @@ -250,11 +250,9 @@ def _load_masks_as_semantic_label( binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths] if self._class_mappings: - mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len( - self.classes - ) + mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(self.classes[1:]) for original_class, mapped_class in self._class_mappings.items(): - mapped_index = self.class_to_idx[mapped_class] + mapped_index = self.class_to_idx[mapped_class] - 1 original_index = list(self._class_mappings.keys()).index(original_class) mapped_binary_masks[mapped_index] = np.logical_or( mapped_binary_masks[mapped_index], binary_masks[original_index]