From 44d0ff6ea79c5ae5aa2ee23406b890c320f84ef0 Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 19 Dec 2024 10:19:48 -0500 Subject: [PATCH 1/2] type: Annotate nireports.tools --- nireports/tools/ndimage.py | 33 +++++++++++++++---- nireports/tools/timeseries.py | 60 ++++++++++++++++++++--------------- 2 files changed, 62 insertions(+), 31 deletions(-) diff --git a/nireports/tools/ndimage.py b/nireports/tools/ndimage.py index c185ece9..a6035108 100644 --- a/nireports/tools/ndimage.py +++ b/nireports/tools/ndimage.py @@ -27,21 +27,32 @@ # niworkflows/utils/images.py """Tooling to manipulate n-dimensional images.""" +import os +import typing as ty + import nibabel as nb import numpy as np +import numpy.typing as npt +from nibabel.spatialimages import SpatialImage + +ImgT = ty.TypeVar("ImgT", bound=nb.filebasedimages.FileBasedImage) +SpatImgT = ty.TypeVar("SpatImgT", bound=SpatialImage) +Mat = npt.NDArray[np.float64] -def rotation2canonical(img): +def rotation2canonical(img: SpatialImage) -> Mat | None: """Calculate the rotation w.r.t. cardinal axes of input image.""" img = nb.as_closest_canonical(img) + # XXX: SpatialImage.affine needs to be typed + affine: Mat = img.affine newaff = np.diag(img.header.get_zooms()[:3]) - r = newaff @ np.linalg.pinv(img.affine[:3, :3]) + r = newaff @ np.linalg.pinv(affine[:3, :3]) if np.allclose(r, np.eye(3)): return None return r -def rotate_affine(img, rot=None): +def rotate_affine(img: SpatImgT, rot: Mat | None = None) -> SpatImgT: """Rewrite the affine of a spatial image.""" if rot is None: return img @@ -52,11 +63,21 @@ def rotate_affine(img, rot=None): return img.__class__(img.dataobj, affine, img.header) -def _get_values_inside_a_mask(main_file, mask_file): - main_nii = nb.load(main_file) +def load_api(path: str | os.PathLike[str], api: type[ImgT]) -> ImgT: + img = nb.load(path) + if not isinstance(img, api): + raise TypeError(f"File {path} does not implement {api} interface") + return img + + +def _get_values_inside_a_mask( + main_file: str | os.PathLike[str], + mask_file: str | os.PathLike[str], +) -> npt.NDArray[np.float64]: + main_nii = load_api(main_file, SpatialImage) main_data = main_nii.get_fdata() nan_mask = np.logical_not(np.isnan(main_data)) - mask = nb.load(mask_file).get_fdata() > 0 + mask = load_api(mask_file, SpatialImage).get_fdata() > 0 data = main_data[np.logical_and(nan_mask, mask)] return data diff --git a/nireports/tools/timeseries.py b/nireports/tools/timeseries.py index 0a5bf468..eb42f8e6 100644 --- a/nireports/tools/timeseries.py +++ b/nireports/tools/timeseries.py @@ -27,11 +27,16 @@ # niworkflows/utils/timeseries.py """Extracting signals from NIfTI and CIFTI2 files.""" +from collections.abc import Sequence + import nibabel as nb import numpy as np +import numpy.typing as npt + +from .ndimage import load_api -def get_tr(img): +def get_tr(img: nb.Nifti1Image | nb.Cifti2Image) -> float: """ Attempt to extract repetition time from NIfTI/CIFTI header. @@ -49,17 +54,18 @@ def get_tr(img): 2.0 """ - - try: + if isinstance(img, nb.Cifti2Image): return img.header.matrix.get_index_map(0).series_step - except AttributeError: + else: return img.header.get_zooms()[-1] raise RuntimeError("Could not extract TR - unknown data structure type") -def cifti_timeseries(dataset): +def cifti_timeseries( + dataset: str | nb.Cifti2Image, +) -> tuple[npt.NDArray[np.float32], dict[str, list[int]]]: """Extract timeseries from CIFTI2 dataset.""" - dataset = nb.load(dataset) if isinstance(dataset, str) else dataset + dataset = load_api(dataset, nb.Cifti2Image) if isinstance(dataset, str) else dataset if dataset.nifti_header.get_intent()[0] != "ConnDenseSeries": raise ValueError("Not a dense timeseries") @@ -71,33 +77,37 @@ def cifti_timeseries(dataset): "CIFTI_STRUCTURE_CEREBELLUM_LEFT": "CbL", "CIFTI_STRUCTURE_CEREBELLUM_RIGHT": "CbR", } - seg = {label: [] for label in list(labels.values()) + ["Other"]} + seg: dict[str, list[int]] = {label: [] for label in list(labels.values()) + ["Other"]} for bm in matrix.get_index_map(1).brain_models: label = labels.get(bm.brain_structure, "Other") seg[label] += list(range(bm.index_offset, bm.index_offset + bm.index_count)) - return dataset.get_fdata(dtype="float32").T, seg + return dataset.get_fdata(dtype=np.float32).T, seg def nifti_timeseries( - dataset, - segmentation=None, - labels=("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"), - remap_rois=False, - lut=None, -): + dataset: str | nb.Nifti1Image, + segmentation: str | nb.Nifti1Image | None = None, + labels: Sequence[str] = ("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"), + remap_rois: bool = False, + lut: npt.NDArray[np.uint8] | None = None, +) -> tuple[npt.NDArray[np.float32], dict[str, list[int]] | None]: """Extract timeseries from NIfTI1/2 datasets.""" - dataset = nb.load(dataset) if isinstance(dataset, str) else dataset - data = dataset.get_fdata(dtype="float32").reshape((-1, dataset.shape[-1])) + dataset = load_api(dataset, nb.Nifti1Image) if isinstance(dataset, str) else dataset + data: npt.NDArray[np.float32] = dataset.get_fdata(dtype="float32").reshape( + (-1, dataset.shape[-1]) + ) if segmentation is None: return data, None # Open NIfTI and extract numpy array - segmentation = nb.load(segmentation) if isinstance(segmentation, str) else segmentation - segmentation = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1) + segmentation = ( + load_api(segmentation, nb.Nifti1Image) if isinstance(segmentation, str) else segmentation + ) + seg_data = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1) - remap_rois = remap_rois or (len(np.unique(segmentation[segmentation > 0])) > len(labels)) + remap_rois = remap_rois or (len(np.unique(seg_data[seg_data > 0])) > len(labels)) # Map segmentation if remap_rois or lut is not None: @@ -108,12 +118,12 @@ def nifti_timeseries( lut[1:11] = 3 # WM+CSF lut[255] = 4 # Cerebellum # Apply lookup table - segmentation = lut[segmentation] + seg_data = lut[seg_data] - fgmask = segmentation > 0 - segmentation = segmentation[fgmask] - seg_dict = {} - for i in np.unique(segmentation): - seg_dict[labels[i - 1]] = np.argwhere(segmentation == i).squeeze() + fgmask = seg_data > 0 + seg_values = seg_data[fgmask] + seg_dict: dict[str, list[int]] = {} + for i in np.unique(seg_values): + seg_dict[labels[i - 1]] = list(np.argwhere(seg_values == i).squeeze()) return data[fgmask], seg_dict From 8c184713c5efcbe1d769f8783ea574c01458c044 Mon Sep 17 00:00:00 2001 From: "Christopher J. Markiewicz" Date: Thu, 19 Dec 2024 13:19:55 -0500 Subject: [PATCH 2/2] type: Annotate nireports.reportlets.utils --- nireports/reportlets/utils.py | 97 +++++++++++++++++++++++------------ 1 file changed, 63 insertions(+), 34 deletions(-) diff --git a/nireports/reportlets/utils.py b/nireports/reportlets/utils.py index f63e5074..1251b6a6 100644 --- a/nireports/reportlets/utils.py +++ b/nireports/reportlets/utils.py @@ -28,33 +28,53 @@ """Helper tools for visualization purposes.""" import base64 +import os import re import subprocess +import typing as ty import warnings from io import StringIO from pathlib import Path from shutil import which from tempfile import TemporaryDirectory +from typing import Literal as L from uuid import uuid4 +import matplotlib as mpl import nibabel as nb import numpy as np -from nipype.utils import filemanip +import numpy.typing as npt +from svgutils.transform import SVGFigure + +from ..tools.ndimage import load_api SVGNS = "http://www.w3.org/2000/svg" +G = ty.TypeVar("G", bound=np.generic) + + +class DisplayObject(ty.Protocol): + frame_axes: mpl.axes.Axes + -def robust_set_limits(data, plot_params, percentiles=(15, 99.8)): +def robust_set_limits( + data: npt.NDArray, + plot_params: dict[str, ty.Any], + percentiles: tuple[float, float] = (15, 99.8), +) -> dict[str, ty.Any]: """Set (vmax, vmin) based on percentiles of the data.""" plot_params["vmin"] = plot_params.get("vmin", np.percentile(data, percentiles[0])) plot_params["vmax"] = plot_params.get("vmax", np.percentile(data, percentiles[1])) return plot_params -def _get_limits(nifti_file, only_plot_noise=False): +def _get_limits( + nifti_file: str | npt.NDArray, + only_plot_noise: bool = False, +) -> tuple[float, float]: if isinstance(nifti_file, str): - nii = nb.as_closest_canonical(nb.load(nifti_file)) - data = nii.get_fdata() + nii = nb.as_closest_canonical(load_api(nifti_file, nb.Nifti1Image)) + data: npt.NDArray = nii.get_fdata() else: data = nifti_file @@ -71,7 +91,7 @@ def _get_limits(nifti_file, only_plot_noise=False): return vmin, vmax -def svg_compress(image, compress="auto"): +def svg_compress(image: str, compress: bool | L["auto"] = "auto") -> str: """Generate a blob SVG from a matplotlib figure, may perform compression.""" # Check availability of svgo and cwebp has_compress = all((which("svgo"), which("cwebp"))) @@ -145,20 +165,20 @@ def svg_compress(image, compress="auto"): return "".join(image_svg) # straight up giant string -def svg2str(display_object, dpi=300): +def svg2str(display_object: DisplayObject, dpi: int = 300) -> str: """Serialize a nilearn display object to string.""" from io import StringIO image_buf = StringIO() - display_object.frame_axes.figure.savefig( - image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k" - ) - display_object.frame_axes.figure.clf() + figure = display_object.frame_axes.figure + assert isinstance(figure, mpl.figure.Figure) + figure.savefig(image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k") + figure.clf() image_buf.seek(0) return image_buf.getvalue() -def combine_svg(svg_list, axis="vertical"): +def combine_svg(svg_list: list[SVGFigure], axis="vertical") -> SVGFigure: """ Composes the input svgs into one standalone svg """ @@ -216,7 +236,11 @@ def combine_svg(svg_list, axis="vertical"): return fig -def extract_svg(display_object, dpi=300, compress="auto"): +def extract_svg( + display_object: DisplayObject, + dpi: int = 300, + compress: bool | L["auto"] = "auto", +) -> str: """Remove the preamble of the svg files generated with nilearn.""" image_svg = svg2str(display_object, dpi) if compress is True or compress == "auto": @@ -238,14 +262,14 @@ def extract_svg(display_object, dpi=300, compress="auto"): return image_svg[start_idx:end_idx] -def _bbox(img_data, bbox_data): +def _bbox(img_data: npt.NDArray[G], bbox_data: npt.NDArray) -> npt.NDArray[G]: """Calculate the bounding box of a binary segmentation.""" B = np.argwhere(bbox_data) (ystart, xstart, zstart), (ystop, xstop, zstop) = B.min(0), B.max(0) + 1 return img_data[ystart:ystop, xstart:xstop, zstart:zstop] -def cuts_from_bbox(mask_nii, cuts=3): +def cuts_from_bbox(mask_nii: nb.Nifti1Image, cuts: int = 3) -> dict[str, list[float]]: """Find equi-spaced cuts for presenting images.""" mask_data = np.asanyarray(mask_nii.dataobj) > 0.0 @@ -291,7 +315,9 @@ def cuts_from_bbox(mask_nii, cuts=3): return {k: list(v) for k, v in zip(["x", "y", "z"], np.around(ras_coords, 3))} -def _3d_in_file(in_file): +def _3d_in_file( + in_file: nb.Nifti1Image | str | os.PathLike | list[str | os.PathLike], +) -> nb.Nifti1Image: """if self.inputs.in_file is 3d, return it. if 4d, pick an arbitrary volume and return that. @@ -300,12 +326,11 @@ def _3d_in_file(in_file): """ from nilearn import image as nlimage - in_file = filemanip.filename_to_list(in_file)[0] + if isinstance(in_file, list): + in_file = in_file[0] - try: - in_file = nb.load(in_file) - except AttributeError: - in_file = in_file + if not isinstance(in_file, nb.Nifti1Image): + in_file = load_api(in_file, nb.Nifti1Image) if len(in_file.shape) == 3: return in_file @@ -313,7 +338,12 @@ def _3d_in_file(in_file): return nlimage.index_img(in_file, 0) -def compose_view(bg_svgs, fg_svgs, ref=0, out_file="report.svg"): +def compose_view( + bg_svgs: list[SVGFigure], + fg_svgs: list[SVGFigure], + ref: int = 0, + out_file: str | os.PathLike[str] = "report.svg", +) -> str: """ Compose svgs into one standalone svg with CSS flickering animation. @@ -338,9 +368,13 @@ def compose_view(bg_svgs, fg_svgs, ref=0, out_file="report.svg"): return str(out_file) -def _compose_view(bg_svgs, fg_svgs, ref=0): +def _compose_view( + bg_svgs: list[SVGFigure], + fg_svgs: list[SVGFigure], + ref: int = 0, +) -> list[str]: from svgutils.compose import Unit - from svgutils.transform import GroupElement, SVGFigure + from svgutils.transform import GroupElement if fg_svgs is None: fg_svgs = [] @@ -350,16 +384,11 @@ def _compose_view(bg_svgs, fg_svgs, ref=0): roots = [f.getroot() for f in svgs] # Query the size of each - sizes = [] - for f in svgs: - viewbox = [float(v) for v in f.root.get("viewBox").split(" ")] - width = int(viewbox[2]) - height = int(viewbox[3]) - sizes.append((width, height)) + sizes = np.array( + [[int(float(val)) for val in f.root.get("viewBox").split(" ")[2:4]] for f in svgs] + ) nsvgs = len(bg_svgs) - sizes = np.array(sizes) - # Calculate the scale to fit all widths width = sizes[ref, 0] scales = width / sizes[:, 0] @@ -416,7 +445,7 @@ def _compose_view(bg_svgs, fg_svgs, ref=0): return svg -def transform_to_2d(data, max_axis): +def transform_to_2d(data: npt.NDArray, max_axis: int) -> npt.NDArray: """ Projects 3d data cube along one axis using maximum intensity with preservation of the signs. Adapted from nilearn. @@ -440,7 +469,7 @@ def transform_to_2d(data, max_axis): return np.rot90(maximum_intensity_data) -def get_parula(): +def get_parula() -> mpl.colors.LinearSegmentedColormap: """Generate a 'parula' colormap.""" from matplotlib.colors import LinearSegmentedColormap