Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Nov 14, 2024
1 parent f382733 commit 6804525
Showing 1 changed file with 7 additions and 9 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -184,21 +184,21 @@ def validate(self) -> None:
return

if self._classes:
first_and_last_labels = (self._classes[0], self._classes[-1])
last_label = self._classes[-1]
n_classes = len(self._classes)
elif self._class_mappings:
classes = sorted(set(self._class_mappings.values()))
first_and_last_labels = (classes[0], classes[-1])
last_label = classes[-1]
n_classes = len(classes)
else:
first_and_last_labels = ("adrenal_gland_left", "vertebrae_T9")
last_label = "vertebrae_T9"
n_classes = 117

_validators.check_dataset_integrity(
self,
length=self._expected_dataset_lengths.get(f"{self._split}_{self._version}", 0),
n_classes=n_classes+1,
first_and_last_labels=first_and_last_labels,
n_classes=n_classes + 1,
first_and_last_labels=("background", last_label),
)

@override
Expand Down Expand Up @@ -250,11 +250,9 @@ def _load_masks_as_semantic_label(
binary_masks = [io.read_nifti(path, slice_index) for path in mask_paths]

if self._class_mappings:
mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(
self.classes
)
mapped_binary_masks = [np.zeros_like(binary_masks[0], dtype=np.bool_)] * len(self.classes[1:])
for original_class, mapped_class in self._class_mappings.items():
mapped_index = self.class_to_idx[mapped_class]
mapped_index = self.class_to_idx[mapped_class] - 1
original_index = list(self._class_mappings.keys()).index(original_class)
mapped_binary_masks[mapped_index] = np.logical_or(
mapped_binary_masks[mapped_index], binary_masks[original_index]
Expand Down

0 comments on commit 6804525

Please sign in to comment.