From e1f0beb389c749c1c4fa6805ab15c51dea541f7f Mon Sep 17 00:00:00 2001 From: Nicolas Kaenzig Date: Tue, 19 Nov 2024 09:39:39 +0000 Subject: [PATCH] fix orientation --- .../segmentation/total_segmentator_2d.py | 27 ++++++++++++------- 1 file changed, 17 insertions(+), 10 deletions(-) diff --git a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py index 9707cd60..53cc0c5f 100644 --- a/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py +++ b/src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py @@ -210,30 +210,31 @@ def load_image(self, index: int) -> tv_tensors.Image: sample_index, slice_index = self._indices[index] image_path = self._get_image_path(sample_index) image_array = io.read_nifti(image_path, slice_index) - image_rgb_array = image_array.repeat(3, axis=2) - return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1)) + image_array = self._fix_orientation(image_array) + return tv_tensors.Image(image_array.copy().transpose(2, 0, 1)) @override def load_mask(self, index: int) -> tv_tensors.Mask: if self._optimize_mask_loading: - return self._load_semantic_label_mask(index) - return self._load_mask(index) + mask = self._load_semantic_label_mask(index) + else: + mask = self._load_mask(index) + mask = self._fix_orientation(mask) + return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64) # type: ignore @override def load_metadata(self, index: int) -> Dict[str, Any]: _, slice_index = self._indices[index] return {"slice_index": slice_index} - def _load_mask(self, index: int) -> tv_tensors.Mask: + def _load_mask(self, index: int) -> npt.NDArray[Any]: sample_index, slice_index = self._indices[index] - semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index) - return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] + return self._load_masks_as_semantic_label(sample_index, slice_index) - def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask: + def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]: """Loads the segmentation mask from a semantic label NifTi file.""" sample_index, slice_index = self._indices[index] - semantic_labels = io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index) - return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue] + return io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index) def _load_masks_as_semantic_label( self, sample_index: int, slice_index: int | None = None @@ -298,6 +299,12 @@ def _process_mask(sample_index: Any, filename: str) -> None: with open(mask_classes_file, "w") as file: file.write(str(self.classes)) + def _fix_orientation(self, array: npt.NDArray): + """Fixes orientation such that table is at the bottom & liver on the left.""" + array = np.rot90(array) + array = np.flip(array, axis=1) + return array + def _get_image_path(self, sample_index: int) -> str: """Returns the corresponding image path.""" sample_dir = self._samples_dirs[sample_index]