Skip to content

Commit

Permalink
fix orientation
Browse files Browse the repository at this point in the history
  • Loading branch information
nkaenzig committed Nov 19, 2024
1 parent 5325a36 commit e1f0beb
Showing 1 changed file with 17 additions and 10 deletions.
27 changes: 17 additions & 10 deletions src/eva/vision/data/datasets/segmentation/total_segmentator_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,30 +210,31 @@ def load_image(self, index: int) -> tv_tensors.Image:
sample_index, slice_index = self._indices[index]
image_path = self._get_image_path(sample_index)
image_array = io.read_nifti(image_path, slice_index)
image_rgb_array = image_array.repeat(3, axis=2)
return tv_tensors.Image(image_rgb_array.transpose(2, 0, 1))
image_array = self._fix_orientation(image_array)
return tv_tensors.Image(image_array.copy().transpose(2, 0, 1))

@override
def load_mask(self, index: int) -> tv_tensors.Mask:
if self._optimize_mask_loading:
return self._load_semantic_label_mask(index)
return self._load_mask(index)
mask = self._load_semantic_label_mask(index)
else:
mask = self._load_mask(index)
mask = self._fix_orientation(mask)
return tv_tensors.Mask(mask.copy().squeeze(), dtype=torch.int64) # type: ignore

@override
def load_metadata(self, index: int) -> Dict[str, Any]:
_, slice_index = self._indices[index]
return {"slice_index": slice_index}

def _load_mask(self, index: int) -> tv_tensors.Mask:
def _load_mask(self, index: int) -> npt.NDArray[Any]:
sample_index, slice_index = self._indices[index]
semantic_labels = self._load_masks_as_semantic_label(sample_index, slice_index)
return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
return self._load_masks_as_semantic_label(sample_index, slice_index)

def _load_semantic_label_mask(self, index: int) -> tv_tensors.Mask:
def _load_semantic_label_mask(self, index: int) -> npt.NDArray[Any]:
"""Loads the segmentation mask from a semantic label NifTi file."""
sample_index, slice_index = self._indices[index]
semantic_labels = io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)
return tv_tensors.Mask(semantic_labels.squeeze(), dtype=torch.int64) # type: ignore[reportCallIssue]
return io.read_nifti(self._get_optimized_masks_file(sample_index), slice_index)

def _load_masks_as_semantic_label(
self, sample_index: int, slice_index: int | None = None
Expand Down Expand Up @@ -298,6 +299,12 @@ def _process_mask(sample_index: Any, filename: str) -> None:
with open(mask_classes_file, "w") as file:
file.write(str(self.classes))

def _fix_orientation(self, array: npt.NDArray):
"""Fixes orientation such that table is at the bottom & liver on the left."""
array = np.rot90(array)
array = np.flip(array, axis=1)
return array

def _get_image_path(self, sample_index: int) -> str:
"""Returns the corresponding image path."""
sample_dir = self._samples_dirs[sample_index]
Expand Down

0 comments on commit e1f0beb

Please sign in to comment.