From 44d0ff6ea79c5ae5aa2ee23406b890c320f84ef0 Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 19 Dec 2024 10:19:48 -0500 Subject: [PATCH] type: Annotate nireports.tools --- nireports/tools/ndimage.py | 33 +++++++++++++++---- nireports/tools/timeseries.py | 60 ++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 31 deletions(-) diff --git a/nireports/tools/ndimage.py b/nireports/tools/ndimage.py index c185ece9..a6035108 100644 --- a/nireports/tools/ndimage.py +++ b/nireports/tools/ndimage.py @@ -27,21 +27,32 @@ # niworkflows/utils/images.py """Tooling to manipulate n-dimensional images.""" +import os +import typing as ty + import nibabel as nb import numpy as np +import numpy.typing as npt +from nibabel.spatialimages import SpatialImage + +ImgT = ty.TypeVar("ImgT", bound=nb.filebasedimages.FileBasedImage) +SpatImgT = ty.TypeVar("SpatImgT", bound=SpatialImage) +Mat = npt.NDArray[np.float64] -def rotation2canonical(img): +def rotation2canonical(img: SpatialImage) -> Mat | None: """Calculate the rotation w.r.t. cardinal axes of input image.""" img = nb.as_closest_canonical(img) + # XXX: SpatialImage.affine needs to be typed + affine: Mat = img.affine newaff = np.diag(img.header.get_zooms()[:3]) - r = newaff @ np.linalg.pinv(img.affine[:3, :3]) + r = newaff @ np.linalg.pinv(affine[:3, :3]) if np.allclose(r, np.eye(3)): return None return r -def rotate_affine(img, rot=None): +def rotate_affine(img: SpatImgT, rot: Mat | None = None) -> SpatImgT: """Rewrite the affine of a spatial image.""" if rot is None: return img @@ -52,11 +63,21 @@ def rotate_affine(img, rot=None): return img.__class__(img.dataobj, affine, img.header) -def _get_values_inside_a_mask(main_file, mask_file): - main_nii = nb.load(main_file) +def load_api(path: str | os.PathLike[str], api: type[ImgT]) -> ImgT: + img = nb.load(path) + if not isinstance(img, api): + raise TypeError(f"File {path} does not implement {api} interface") + return img + + +def _get_values_inside_a_mask( + main_file: str | os.PathLike[str], + mask_file: str | os.PathLike[str], +) -> npt.NDArray[np.float64]: + main_nii = load_api(main_file, SpatialImage) main_data = main_nii.get_fdata() nan_mask = np.logical_not(np.isnan(main_data)) - mask = nb.load(mask_file).get_fdata() > 0 + mask = load_api(mask_file, SpatialImage).get_fdata() > 0 data = main_data[np.logical_and(nan_mask, mask)] return data diff --git a/nireports/tools/timeseries.py b/nireports/tools/timeseries.py index 0a5bf468..eb42f8e6 100644 --- a/nireports/tools/timeseries.py +++ b/nireports/tools/timeseries.py @@ -27,11 +27,16 @@ # niworkflows/utils/timeseries.py """Extracting signals from NIfTI and CIFTI2 files.""" +from collections.abc import Sequence + import nibabel as nb import numpy as np +import numpy.typing as npt + +from .ndimage import load_api -def get_tr(img): +def get_tr(img: nb.Nifti1Image | nb.Cifti2Image) -> float: """ Attempt to extract repetition time from NIfTI/CIFTI header. @@ -49,17 +54,18 @@ def get_tr(img): 2.0 """ - - try: + if isinstance(img, nb.Cifti2Image): return img.header.matrix.get_index_map(0).series_step - except AttributeError: + else: return img.header.get_zooms()[-1] raise RuntimeError("Could not extract TR - unknown data structure type") -def cifti_timeseries(dataset): +def cifti_timeseries( + dataset: str | nb.Cifti2Image, +) -> tuple[npt.NDArray[np.float32], dict[str, list[int]]]: """Extract timeseries from CIFTI2 dataset.""" - dataset = nb.load(dataset) if isinstance(dataset, str) else dataset + dataset = load_api(dataset, nb.Cifti2Image) if isinstance(dataset, str) else dataset if dataset.nifti_header.get_intent()[0] != "ConnDenseSeries": raise ValueError("Not a dense timeseries") @@ -71,33 +77,37 @@ def cifti_timeseries(dataset): "CIFTI_STRUCTURE_CEREBELLUM_LEFT": "CbL", "CIFTI_STRUCTURE_CEREBELLUM_RIGHT": "CbR", } - seg = {label: [] for label in list(labels.values()) + ["Other"]} + seg: dict[str, list[int]] = {label: [] for label in list(labels.values()) + ["Other"]} for bm in matrix.get_index_map(1).brain_models: label = labels.get(bm.brain_structure, "Other") seg[label] += list(range(bm.index_offset, bm.index_offset + bm.index_count)) - return dataset.get_fdata(dtype="float32").T, seg + return dataset.get_fdata(dtype=np.float32).T, seg def nifti_timeseries( - dataset, - segmentation=None, - labels=("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"), - remap_rois=False, - lut=None, -): + dataset: str | nb.Nifti1Image, + segmentation: str | nb.Nifti1Image | None = None, + labels: Sequence[str] = ("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"), + remap_rois: bool = False, + lut: npt.NDArray[np.uint8] | None = None, +) -> tuple[npt.NDArray[np.float32], dict[str, list[int]] | None]: """Extract timeseries from NIfTI1/2 datasets.""" - dataset = nb.load(dataset) if isinstance(dataset, str) else dataset - data = dataset.get_fdata(dtype="float32").reshape((-1, dataset.shape[-1])) + dataset = load_api(dataset, nb.Nifti1Image) if isinstance(dataset, str) else dataset + data: npt.NDArray[np.float32] = dataset.get_fdata(dtype="float32").reshape( + (-1, dataset.shape[-1]) + ) if segmentation is None: return data, None # Open NIfTI and extract numpy array - segmentation = nb.load(segmentation) if isinstance(segmentation, str) else segmentation - segmentation = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1) + segmentation = ( + load_api(segmentation, nb.Nifti1Image) if isinstance(segmentation, str) else segmentation + ) + seg_data = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1) - remap_rois = remap_rois or (len(np.unique(segmentation[segmentation > 0])) > len(labels)) + remap_rois = remap_rois or (len(np.unique(seg_data[seg_data > 0])) > len(labels)) # Map segmentation if remap_rois or lut is not None: @@ -108,12 +118,12 @@ def nifti_timeseries( lut[1:11] = 3 # WM+CSF lut[255] = 4 # Cerebellum # Apply lookup table - segmentation = lut[segmentation] + seg_data = lut[seg_data] - fgmask = segmentation > 0 - segmentation = segmentation[fgmask] - seg_dict = {} - for i in np.unique(segmentation): - seg_dict[labels[i - 1]] = np.argwhere(segmentation == i).squeeze() + fgmask = seg_data > 0 + seg_values = seg_data[fgmask] + seg_dict: dict[str, list[int]] = {} + for i in np.unique(seg_values): + seg_dict[labels[i - 1]] = list(np.argwhere(seg_values == i).squeeze()) return data[fgmask], seg_dict