Skip to content

Commit

Permalink
type: Annotate nireports.tools
Browse files Browse the repository at this point in the history
  • Loading branch information
effigies committed Dec 19, 2024
1 parent f0690a1 commit 44d0ff6
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 31 deletions.
33 changes: 27 additions & 6 deletions nireports/tools/ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,32 @@
# niworkflows/utils/images.py
"""Tooling to manipulate n-dimensional images."""

import os
import typing as ty

import nibabel as nb
import numpy as np
import numpy.typing as npt
from nibabel.spatialimages import SpatialImage

ImgT = ty.TypeVar("ImgT", bound=nb.filebasedimages.FileBasedImage)
SpatImgT = ty.TypeVar("SpatImgT", bound=SpatialImage)
Mat = npt.NDArray[np.float64]


def rotation2canonical(img):
def rotation2canonical(img: SpatialImage) -> Mat | None:
"""Calculate the rotation w.r.t. cardinal axes of input image."""
img = nb.as_closest_canonical(img)
# XXX: SpatialImage.affine needs to be typed
affine: Mat = img.affine
newaff = np.diag(img.header.get_zooms()[:3])
r = newaff @ np.linalg.pinv(img.affine[:3, :3])
r = newaff @ np.linalg.pinv(affine[:3, :3])
if np.allclose(r, np.eye(3)):
return None
return r


def rotate_affine(img, rot=None):
def rotate_affine(img: SpatImgT, rot: Mat | None = None) -> SpatImgT:
"""Rewrite the affine of a spatial image."""
if rot is None:
return img
Expand All @@ -52,11 +63,21 @@ def rotate_affine(img, rot=None):
return img.__class__(img.dataobj, affine, img.header)


def _get_values_inside_a_mask(main_file, mask_file):
main_nii = nb.load(main_file)
def load_api(path: str | os.PathLike[str], api: type[ImgT]) -> ImgT:
img = nb.load(path)
if not isinstance(img, api):
raise TypeError(f"File {path} does not implement {api} interface")
return img


def _get_values_inside_a_mask(
main_file: str | os.PathLike[str],
mask_file: str | os.PathLike[str],
) -> npt.NDArray[np.float64]:
main_nii = load_api(main_file, SpatialImage)
main_data = main_nii.get_fdata()
nan_mask = np.logical_not(np.isnan(main_data))
mask = nb.load(mask_file).get_fdata() > 0
mask = load_api(mask_file, SpatialImage).get_fdata() > 0

data = main_data[np.logical_and(nan_mask, mask)]
return data
60 changes: 35 additions & 25 deletions nireports/tools/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@
# niworkflows/utils/timeseries.py
"""Extracting signals from NIfTI and CIFTI2 files."""

from collections.abc import Sequence

import nibabel as nb
import numpy as np
import numpy.typing as npt

from .ndimage import load_api


def get_tr(img):
def get_tr(img: nb.Nifti1Image | nb.Cifti2Image) -> float:
"""
Attempt to extract repetition time from NIfTI/CIFTI header.
Expand All @@ -49,17 +54,18 @@ def get_tr(img):
2.0
"""

try:
if isinstance(img, nb.Cifti2Image):
return img.header.matrix.get_index_map(0).series_step
except AttributeError:
else:
return img.header.get_zooms()[-1]
raise RuntimeError("Could not extract TR - unknown data structure type")


def cifti_timeseries(dataset):
def cifti_timeseries(
dataset: str | nb.Cifti2Image,
) -> tuple[npt.NDArray[np.float32], dict[str, list[int]]]:
"""Extract timeseries from CIFTI2 dataset."""
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
dataset = load_api(dataset, nb.Cifti2Image) if isinstance(dataset, str) else dataset

if dataset.nifti_header.get_intent()[0] != "ConnDenseSeries":
raise ValueError("Not a dense timeseries")
Expand All @@ -71,33 +77,37 @@ def cifti_timeseries(dataset):
"CIFTI_STRUCTURE_CEREBELLUM_LEFT": "CbL",
"CIFTI_STRUCTURE_CEREBELLUM_RIGHT": "CbR",
}
seg = {label: [] for label in list(labels.values()) + ["Other"]}
seg: dict[str, list[int]] = {label: [] for label in list(labels.values()) + ["Other"]}
for bm in matrix.get_index_map(1).brain_models:
label = labels.get(bm.brain_structure, "Other")
seg[label] += list(range(bm.index_offset, bm.index_offset + bm.index_count))

return dataset.get_fdata(dtype="float32").T, seg
return dataset.get_fdata(dtype=np.float32).T, seg


def nifti_timeseries(
dataset,
segmentation=None,
labels=("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"),
remap_rois=False,
lut=None,
):
dataset: str | nb.Nifti1Image,
segmentation: str | nb.Nifti1Image | None = None,
labels: Sequence[str] = ("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"),
remap_rois: bool = False,
lut: npt.NDArray[np.uint8] | None = None,
) -> tuple[npt.NDArray[np.float32], dict[str, list[int]] | None]:
"""Extract timeseries from NIfTI1/2 datasets."""
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
data = dataset.get_fdata(dtype="float32").reshape((-1, dataset.shape[-1]))
dataset = load_api(dataset, nb.Nifti1Image) if isinstance(dataset, str) else dataset
data: npt.NDArray[np.float32] = dataset.get_fdata(dtype="float32").reshape(
(-1, dataset.shape[-1])
)

if segmentation is None:
return data, None

# Open NIfTI and extract numpy array
segmentation = nb.load(segmentation) if isinstance(segmentation, str) else segmentation
segmentation = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1)
segmentation = (
load_api(segmentation, nb.Nifti1Image) if isinstance(segmentation, str) else segmentation
)
seg_data = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1)

remap_rois = remap_rois or (len(np.unique(segmentation[segmentation > 0])) > len(labels))
remap_rois = remap_rois or (len(np.unique(seg_data[seg_data > 0])) > len(labels))

# Map segmentation
if remap_rois or lut is not None:
Expand All @@ -108,12 +118,12 @@ def nifti_timeseries(
lut[1:11] = 3 # WM+CSF
lut[255] = 4 # Cerebellum
# Apply lookup table
segmentation = lut[segmentation]
seg_data = lut[seg_data]

fgmask = segmentation > 0
segmentation = segmentation[fgmask]
seg_dict = {}
for i in np.unique(segmentation):
seg_dict[labels[i - 1]] = np.argwhere(segmentation == i).squeeze()
fgmask = seg_data > 0
seg_values = seg_data[fgmask]
seg_dict: dict[str, list[int]] = {}
for i in np.unique(seg_values):
seg_dict[labels[i - 1]] = list(np.argwhere(seg_values == i).squeeze())

return data[fgmask], seg_dict

0 comments on commit 44d0ff6

Please sign in to comment.