Skip to content

Commit

Permalink
Merge pull request #147 from effigies/typ/annotations
Browse files Browse the repository at this point in the history
type: Annotate tools and reportlets.utils
  • Loading branch information
jhlegarreta authored Dec 26, 2024
2 parents b79af0f + 8c18471 commit e89ab63
Show file tree
Hide file tree
Showing 3 changed files with 125 additions and 65 deletions.
97 changes: 63 additions & 34 deletions nireports/reportlets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,33 +28,53 @@
"""Helper tools for visualization purposes."""

import base64
import os
import re
import subprocess
import typing as ty
import warnings
from io import StringIO
from pathlib import Path
from shutil import which
from tempfile import TemporaryDirectory
from typing import Literal as L
from uuid import uuid4

import matplotlib as mpl
import nibabel as nb
import numpy as np
from nipype.utils import filemanip
import numpy.typing as npt
from svgutils.transform import SVGFigure

from ..tools.ndimage import load_api

SVGNS = "http://www.w3.org/2000/svg"

G = ty.TypeVar("G", bound=np.generic)


class DisplayObject(ty.Protocol):
frame_axes: mpl.axes.Axes


def robust_set_limits(data, plot_params, percentiles=(15, 99.8)):
def robust_set_limits(
data: npt.NDArray,
plot_params: dict[str, ty.Any],
percentiles: tuple[float, float] = (15, 99.8),
) -> dict[str, ty.Any]:
"""Set (vmax, vmin) based on percentiles of the data."""
plot_params["vmin"] = plot_params.get("vmin", np.percentile(data, percentiles[0]))
plot_params["vmax"] = plot_params.get("vmax", np.percentile(data, percentiles[1]))
return plot_params


def _get_limits(nifti_file, only_plot_noise=False):
def _get_limits(
nifti_file: str | npt.NDArray,
only_plot_noise: bool = False,
) -> tuple[float, float]:
if isinstance(nifti_file, str):
nii = nb.as_closest_canonical(nb.load(nifti_file))
data = nii.get_fdata()
nii = nb.as_closest_canonical(load_api(nifti_file, nb.Nifti1Image))
data: npt.NDArray = nii.get_fdata()
else:
data = nifti_file

Expand All @@ -71,7 +91,7 @@ def _get_limits(nifti_file, only_plot_noise=False):
return vmin, vmax


def svg_compress(image, compress="auto"):
def svg_compress(image: str, compress: bool | L["auto"] = "auto") -> str:
"""Generate a blob SVG from a matplotlib figure, may perform compression."""
# Check availability of svgo and cwebp
has_compress = all((which("svgo"), which("cwebp")))
Expand Down Expand Up @@ -145,20 +165,20 @@ def svg_compress(image, compress="auto"):
return "".join(image_svg) # straight up giant string


def svg2str(display_object, dpi=300):
def svg2str(display_object: DisplayObject, dpi: int = 300) -> str:
"""Serialize a nilearn display object to string."""
from io import StringIO

image_buf = StringIO()
display_object.frame_axes.figure.savefig(
image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k"
)
display_object.frame_axes.figure.clf()
figure = display_object.frame_axes.figure
assert isinstance(figure, mpl.figure.Figure)
figure.savefig(image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k")
figure.clf()
image_buf.seek(0)
return image_buf.getvalue()


def combine_svg(svg_list, axis="vertical"):
def combine_svg(svg_list: list[SVGFigure], axis="vertical") -> SVGFigure:
"""
Composes the input svgs into one standalone svg
"""
Expand Down Expand Up @@ -216,7 +236,11 @@ def combine_svg(svg_list, axis="vertical"):
return fig


def extract_svg(display_object, dpi=300, compress="auto"):
def extract_svg(
display_object: DisplayObject,
dpi: int = 300,
compress: bool | L["auto"] = "auto",
) -> str:
"""Remove the preamble of the svg files generated with nilearn."""
image_svg = svg2str(display_object, dpi)
if compress is True or compress == "auto":
Expand All @@ -238,14 +262,14 @@ def extract_svg(display_object, dpi=300, compress="auto"):
return image_svg[start_idx:end_idx]


def _bbox(img_data, bbox_data):
def _bbox(img_data: npt.NDArray[G], bbox_data: npt.NDArray) -> npt.NDArray[G]:
"""Calculate the bounding box of a binary segmentation."""
B = np.argwhere(bbox_data)
(ystart, xstart, zstart), (ystop, xstop, zstop) = B.min(0), B.max(0) + 1
return img_data[ystart:ystop, xstart:xstop, zstart:zstop]


def cuts_from_bbox(mask_nii, cuts=3):
def cuts_from_bbox(mask_nii: nb.Nifti1Image, cuts: int = 3) -> dict[str, list[float]]:
"""Find equi-spaced cuts for presenting images."""
mask_data = np.asanyarray(mask_nii.dataobj) > 0.0

Expand Down Expand Up @@ -291,7 +315,9 @@ def cuts_from_bbox(mask_nii, cuts=3):
return {k: list(v) for k, v in zip(["x", "y", "z"], np.around(ras_coords, 3))}


def _3d_in_file(in_file):
def _3d_in_file(
in_file: nb.Nifti1Image | str | os.PathLike | list[str | os.PathLike],
) -> nb.Nifti1Image:
"""if self.inputs.in_file is 3d, return it.
if 4d, pick an arbitrary volume and return that.
Expand All @@ -300,20 +326,24 @@ def _3d_in_file(in_file):
"""
from nilearn import image as nlimage

in_file = filemanip.filename_to_list(in_file)[0]
if isinstance(in_file, list):
in_file = in_file[0]

try:
in_file = nb.load(in_file)
except AttributeError:
in_file = in_file
if not isinstance(in_file, nb.Nifti1Image):
in_file = load_api(in_file, nb.Nifti1Image)

if len(in_file.shape) == 3:
return in_file

return nlimage.index_img(in_file, 0)


def compose_view(bg_svgs, fg_svgs, ref=0, out_file="report.svg"):
def compose_view(
bg_svgs: list[SVGFigure],
fg_svgs: list[SVGFigure],
ref: int = 0,
out_file: str | os.PathLike[str] = "report.svg",
) -> str:
"""
Compose svgs into one standalone svg with CSS flickering animation.
Expand All @@ -338,9 +368,13 @@ def compose_view(bg_svgs, fg_svgs, ref=0, out_file="report.svg"):
return str(out_file)


def _compose_view(bg_svgs, fg_svgs, ref=0):
def _compose_view(
bg_svgs: list[SVGFigure],
fg_svgs: list[SVGFigure],
ref: int = 0,
) -> list[str]:
from svgutils.compose import Unit
from svgutils.transform import GroupElement, SVGFigure
from svgutils.transform import GroupElement

if fg_svgs is None:
fg_svgs = []
Expand All @@ -350,16 +384,11 @@ def _compose_view(bg_svgs, fg_svgs, ref=0):
roots = [f.getroot() for f in svgs]

# Query the size of each
sizes = []
for f in svgs:
viewbox = [float(v) for v in f.root.get("viewBox").split(" ")]
width = int(viewbox[2])
height = int(viewbox[3])
sizes.append((width, height))
sizes = np.array(
[[int(float(val)) for val in f.root.get("viewBox").split(" ")[2:4]] for f in svgs]
)
nsvgs = len(bg_svgs)

sizes = np.array(sizes)

# Calculate the scale to fit all widths
width = sizes[ref, 0]
scales = width / sizes[:, 0]
Expand Down Expand Up @@ -416,7 +445,7 @@ def _compose_view(bg_svgs, fg_svgs, ref=0):
return svg


def transform_to_2d(data, max_axis):
def transform_to_2d(data: npt.NDArray, max_axis: int) -> npt.NDArray:
"""
Projects 3d data cube along one axis using maximum intensity with
preservation of the signs. Adapted from nilearn.
Expand All @@ -440,7 +469,7 @@ def transform_to_2d(data, max_axis):
return np.rot90(maximum_intensity_data)


def get_parula():
def get_parula() -> mpl.colors.LinearSegmentedColormap:
"""Generate a 'parula' colormap."""
from matplotlib.colors import LinearSegmentedColormap

Expand Down
33 changes: 27 additions & 6 deletions nireports/tools/ndimage.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,21 +27,32 @@
# niworkflows/utils/images.py
"""Tooling to manipulate n-dimensional images."""

import os
import typing as ty

import nibabel as nb
import numpy as np
import numpy.typing as npt
from nibabel.spatialimages import SpatialImage

ImgT = ty.TypeVar("ImgT", bound=nb.filebasedimages.FileBasedImage)
SpatImgT = ty.TypeVar("SpatImgT", bound=SpatialImage)
Mat = npt.NDArray[np.float64]


def rotation2canonical(img):
def rotation2canonical(img: SpatialImage) -> Mat | None:
"""Calculate the rotation w.r.t. cardinal axes of input image."""
img = nb.as_closest_canonical(img)
# XXX: SpatialImage.affine needs to be typed
affine: Mat = img.affine
newaff = np.diag(img.header.get_zooms()[:3])
r = newaff @ np.linalg.pinv(img.affine[:3, :3])
r = newaff @ np.linalg.pinv(affine[:3, :3])
if np.allclose(r, np.eye(3)):
return None
return r


def rotate_affine(img, rot=None):
def rotate_affine(img: SpatImgT, rot: Mat | None = None) -> SpatImgT:
"""Rewrite the affine of a spatial image."""
if rot is None:
return img
Expand All @@ -52,11 +63,21 @@ def rotate_affine(img, rot=None):
return img.__class__(img.dataobj, affine, img.header)


def _get_values_inside_a_mask(main_file, mask_file):
main_nii = nb.load(main_file)
def load_api(path: str | os.PathLike[str], api: type[ImgT]) -> ImgT:
img = nb.load(path)
if not isinstance(img, api):
raise TypeError(f"File {path} does not implement {api} interface")
return img


def _get_values_inside_a_mask(
main_file: str | os.PathLike[str],
mask_file: str | os.PathLike[str],
) -> npt.NDArray[np.float64]:
main_nii = load_api(main_file, SpatialImage)
main_data = main_nii.get_fdata()
nan_mask = np.logical_not(np.isnan(main_data))
mask = nb.load(mask_file).get_fdata() > 0
mask = load_api(mask_file, SpatialImage).get_fdata() > 0

data = main_data[np.logical_and(nan_mask, mask)]
return data
60 changes: 35 additions & 25 deletions nireports/tools/timeseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,16 @@
# niworkflows/utils/timeseries.py
"""Extracting signals from NIfTI and CIFTI2 files."""

from collections.abc import Sequence

import nibabel as nb
import numpy as np
import numpy.typing as npt

from .ndimage import load_api


def get_tr(img):
def get_tr(img: nb.Nifti1Image | nb.Cifti2Image) -> float:
"""
Attempt to extract repetition time from NIfTI/CIFTI header.
Expand All @@ -49,17 +54,18 @@ def get_tr(img):
2.0
"""

try:
if isinstance(img, nb.Cifti2Image):
return img.header.matrix.get_index_map(0).series_step
except AttributeError:
else:
return img.header.get_zooms()[-1]
raise RuntimeError("Could not extract TR - unknown data structure type")


def cifti_timeseries(dataset):
def cifti_timeseries(
dataset: str | nb.Cifti2Image,
) -> tuple[npt.NDArray[np.float32], dict[str, list[int]]]:
"""Extract timeseries from CIFTI2 dataset."""
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
dataset = load_api(dataset, nb.Cifti2Image) if isinstance(dataset, str) else dataset

if dataset.nifti_header.get_intent()[0] != "ConnDenseSeries":
raise ValueError("Not a dense timeseries")
Expand All @@ -71,33 +77,37 @@ def cifti_timeseries(dataset):
"CIFTI_STRUCTURE_CEREBELLUM_LEFT": "CbL",
"CIFTI_STRUCTURE_CEREBELLUM_RIGHT": "CbR",
}
seg = {label: [] for label in list(labels.values()) + ["Other"]}
seg: dict[str, list[int]] = {label: [] for label in list(labels.values()) + ["Other"]}
for bm in matrix.get_index_map(1).brain_models:
label = labels.get(bm.brain_structure, "Other")
seg[label] += list(range(bm.index_offset, bm.index_offset + bm.index_count))

return dataset.get_fdata(dtype="float32").T, seg
return dataset.get_fdata(dtype=np.float32).T, seg


def nifti_timeseries(
dataset,
segmentation=None,
labels=("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"),
remap_rois=False,
lut=None,
):
dataset: str | nb.Nifti1Image,
segmentation: str | nb.Nifti1Image | None = None,
labels: Sequence[str] = ("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"),
remap_rois: bool = False,
lut: npt.NDArray[np.uint8] | None = None,
) -> tuple[npt.NDArray[np.float32], dict[str, list[int]] | None]:
"""Extract timeseries from NIfTI1/2 datasets."""
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
data = dataset.get_fdata(dtype="float32").reshape((-1, dataset.shape[-1]))
dataset = load_api(dataset, nb.Nifti1Image) if isinstance(dataset, str) else dataset
data: npt.NDArray[np.float32] = dataset.get_fdata(dtype="float32").reshape(
(-1, dataset.shape[-1])
)

if segmentation is None:
return data, None

# Open NIfTI and extract numpy array
segmentation = nb.load(segmentation) if isinstance(segmentation, str) else segmentation
segmentation = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1)
segmentation = (
load_api(segmentation, nb.Nifti1Image) if isinstance(segmentation, str) else segmentation
)
seg_data = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1)

remap_rois = remap_rois or (len(np.unique(segmentation[segmentation > 0])) > len(labels))
remap_rois = remap_rois or (len(np.unique(seg_data[seg_data > 0])) > len(labels))

# Map segmentation
if remap_rois or lut is not None:
Expand All @@ -108,12 +118,12 @@ def nifti_timeseries(
lut[1:11] = 3 # WM+CSF
lut[255] = 4 # Cerebellum
# Apply lookup table
segmentation = lut[segmentation]
seg_data = lut[seg_data]

fgmask = segmentation > 0
segmentation = segmentation[fgmask]
seg_dict = {}
for i in np.unique(segmentation):
seg_dict[labels[i - 1]] = np.argwhere(segmentation == i).squeeze()
fgmask = seg_data > 0
seg_values = seg_data[fgmask]
seg_dict: dict[str, list[int]] = {}
for i in np.unique(seg_values):
seg_dict[labels[i - 1]] = list(np.argwhere(seg_values == i).squeeze())

return data[fgmask], seg_dict

0 comments on commit e89ab63

Please sign in to comment.