type: Annotate tools and reportlets.utils #147

Merged: 2 commits, Dec 26, 2024
97 changes: 63 additions & 34 deletions nireports/reportlets/utils.py
@@ -28,33 +28,53 @@
"""Helper tools for visualization purposes."""

import base64
import os
import re
import subprocess
import typing as ty
import warnings
from io import StringIO
from pathlib import Path
from shutil import which
from tempfile import TemporaryDirectory
from typing import Literal as L
from uuid import uuid4

import matplotlib as mpl
import nibabel as nb
import numpy as np
from nipype.utils import filemanip
import numpy.typing as npt
from svgutils.transform import SVGFigure

from ..tools.ndimage import load_api

SVGNS = "http://www.w3.org/2000/svg"

G = ty.TypeVar("G", bound=np.generic)


class DisplayObject(ty.Protocol):
frame_axes: mpl.axes.Axes


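A note on the new DisplayObject protocol above: because it derives from typing.Protocol, nilearn's display classes satisfy it structurally, without inheriting from it. A minimal sketch, with a hypothetical ToyDisplay standing in for a real nilearn display:

import matplotlib.pyplot as plt

from nireports.reportlets.utils import DisplayObject

class ToyDisplay:
    """Hypothetical stand-in for a nilearn display object."""

    def __init__(self) -> None:
        # Any object exposing a matplotlib Axes as ``frame_axes`` matches.
        self.frame_axes = plt.figure().gca()

def render(display: DisplayObject) -> None:
    display.frame_axes.figure.suptitle("ok")

render(ToyDisplay())  # a static checker accepts this: structural match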
def robust_set_limits(data, plot_params, percentiles=(15, 99.8)):
def robust_set_limits(
data: npt.NDArray,
plot_params: dict[str, ty.Any],
percentiles: tuple[float, float] = (15, 99.8),
) -> dict[str, ty.Any]:
"""Set (vmax, vmin) based on percentiles of the data."""
plot_params["vmin"] = plot_params.get("vmin", np.percentile(data, percentiles[0]))
plot_params["vmax"] = plot_params.get("vmax", np.percentile(data, percentiles[1]))
return plot_params

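For orientation, a small usage sketch of the freshly annotated signature (the random data is purely illustrative):

import numpy as np

from nireports.reportlets.utils import robust_set_limits

data = np.random.default_rng(0).normal(size=(64, 64))
plot_params = robust_set_limits(data, {"cmap": "gray"})
# vmin/vmax are filled in from percentiles only when the caller has not set them.
print(plot_params["vmin"], plot_params["vmax"])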

def _get_limits(nifti_file, only_plot_noise=False):
def _get_limits(
nifti_file: str | npt.NDArray,
only_plot_noise: bool = False,
) -> tuple[float, float]:
if isinstance(nifti_file, str):
nii = nb.as_closest_canonical(nb.load(nifti_file))
data = nii.get_fdata()
nii = nb.as_closest_canonical(load_api(nifti_file, nb.Nifti1Image))
data: npt.NDArray = nii.get_fdata()

[Codecov warning: added lines nireports/reportlets/utils.py#L76-L77 not covered by tests]
else:
data = nifti_file

@@ -71,7 +91,7 @@
return vmin, vmax


def svg_compress(image, compress="auto"):
def svg_compress(image: str, compress: bool | L["auto"] = "auto") -> str:
"""Generate a blob SVG from a matplotlib figure, may perform compression."""
# Check availability of svgo and cwebp
has_compress = all((which("svgo"), which("cwebp")))
@@ -145,20 +165,20 @@
return "".join(image_svg) # straight up giant string


def svg2str(display_object, dpi=300):
def svg2str(display_object: DisplayObject, dpi: int = 300) -> str:
"""Serialize a nilearn display object to string."""
from io import StringIO

image_buf = StringIO()
display_object.frame_axes.figure.savefig(
image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k"
)
display_object.frame_axes.figure.clf()
figure = display_object.frame_axes.figure
assert isinstance(figure, mpl.figure.Figure)
figure.savefig(image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k")
figure.clf()
image_buf.seek(0)
return image_buf.getvalue()


def combine_svg(svg_list, axis="vertical"):
def combine_svg(svg_list: list[SVGFigure], axis="vertical") -> SVGFigure:
"""
Composes the input svgs into one standalone svg
"""
@@ -216,7 +236,7 @@
return fig


def extract_svg(display_object, dpi=300, compress="auto"):
def extract_svg(
display_object: DisplayObject,
dpi: int = 300,
compress: bool | L["auto"] = "auto",
) -> str:
"""Remove the preamble of the svg files generated with nilearn."""
image_svg = svg2str(display_object, dpi)
if compress is True or compress == "auto":
@@ -238,14 +262,14 @@
return image_svg[start_idx:end_idx]

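The compress arguments here are tri-state, hence the bool | Literal["auto"] annotation: True forces compression, False disables it, and "auto" defers to tool availability. A standalone sketch of that narrowing logic (not code from this PR):

import typing as ty

def resolve_compress(compress: bool | ty.Literal["auto"], has_tools: bool) -> bool:
    # "auto" compresses only when svgo/cwebp are on PATH; booleans are explicit.
    if compress == "auto":
        return has_tools
    return compress

print(resolve_compress("auto", has_tools=False))  # False
print(resolve_compress(True, has_tools=False))    # True, the caller insists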

def _bbox(img_data, bbox_data):
def _bbox(img_data: npt.NDArray[G], bbox_data: npt.NDArray) -> npt.NDArray[G]:
"""Calculate the bounding box of a binary segmentation."""
B = np.argwhere(bbox_data)
(ystart, xstart, zstart), (ystop, xstop, zstop) = B.min(0), B.max(0) + 1
return img_data[ystart:ystop, xstart:xstop, zstart:zstop]

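The G = ty.TypeVar("G", bound=np.generic) added at the top of the module lets _bbox promise that cropping preserves the input dtype. Illustrative use of the private helper:

import numpy as np

from nireports.reportlets.utils import _bbox  # private helper, imported for illustration

img = np.arange(27, dtype=np.int16).reshape(3, 3, 3)
cropped = _bbox(img, img > 20)
# A checker infers npt.NDArray[np.int16]: the dtype survives the crop.
print(cropped.dtype, cropped.shape)  # int16 (1, 2, 3)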

def cuts_from_bbox(mask_nii, cuts=3):
def cuts_from_bbox(mask_nii: nb.Nifti1Image, cuts: int = 3) -> dict[str, list[float]]:
"""Find equi-spaced cuts for presenting images."""
mask_data = np.asanyarray(mask_nii.dataobj) > 0.0

@@ -291,7 +315,9 @@
return {k: list(v) for k, v in zip(["x", "y", "z"], np.around(ras_coords, 3))}

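A runnable sketch of the annotated cuts_from_bbox, using a toy cubic mask with an identity affine:

import nibabel as nb
import numpy as np

from nireports.reportlets.utils import cuts_from_bbox

data = np.zeros((32, 32, 32), dtype=np.uint8)
data[8:24, 8:24, 8:24] = 1  # cubic "brain" mask
mask_nii = nb.Nifti1Image(data, np.eye(4))
print(cuts_from_bbox(mask_nii, cuts=3))  # {'x': [...], 'y': [...], 'z': [...]} in world coordinates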

def _3d_in_file(in_file):
def _3d_in_file(
in_file: nb.Nifti1Image | str | os.PathLike | list[str | os.PathLike],
) -> nb.Nifti1Image:
"""if self.inputs.in_file is 3d, return it.
if 4d, pick an arbitrary volume and return that.

@@ -300,20 +326,24 @@
"""
from nilearn import image as nlimage

in_file = filemanip.filename_to_list(in_file)[0]
if isinstance(in_file, list):
in_file = in_file[0]

[Codecov warning: added line nireports/reportlets/utils.py#L330 not covered by tests]

try:
in_file = nb.load(in_file)
except AttributeError:
in_file = in_file
if not isinstance(in_file, nb.Nifti1Image):
in_file = load_api(in_file, nb.Nifti1Image)

if len(in_file.shape) == 3:
return in_file

return nlimage.index_img(in_file, 0)

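The behavior the docstring describes can be exercised in memory, since the function now also accepts images directly:

import nibabel as nb
import numpy as np

from nireports.reportlets.utils import _3d_in_file  # private, shown for illustration

vol4d = nb.Nifti1Image(np.zeros((3, 3, 3, 7)), np.eye(4))
print(_3d_in_file(vol4d).shape)  # (3, 3, 3): the first volume of the 4-D input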

def compose_view(bg_svgs, fg_svgs, ref=0, out_file="report.svg"):
def compose_view(
bg_svgs: list[SVGFigure],
fg_svgs: list[SVGFigure],
ref: int = 0,
out_file: str | os.PathLike[str] = "report.svg",
) -> str:
"""
Compose svgs into one standalone svg with CSS flickering animation.

@@ -338,9 +368,13 @@
return str(out_file)

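How the annotated compose_view is typically driven, here with two minimal handwritten layers rather than real nilearn renderings. A sketch, assuming svgutils parses the inline SVG (compose_view reads sizes from the viewBox):

import svgutils.transform as svgt

from nireports.reportlets.utils import compose_view

def layer(color: str) -> svgt.SVGFigure:
    # Minimal standalone SVG layer, illustrative only.
    return svgt.fromstring(
        f'<svg xmlns="http://www.w3.org/2000/svg" width="100" height="100" '
        f'viewBox="0 0 100 100"><rect width="100" height="100" fill="{color}"/></svg>'
    )

out = compose_view([layer("black")], [layer("red")], out_file="report.svg")
print(out)  # report.svg flickers between the background and foreground layers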

def _compose_view(bg_svgs, fg_svgs, ref=0):
def _compose_view(
bg_svgs: list[SVGFigure],
fg_svgs: list[SVGFigure],
ref: int = 0,
) -> list[str]:
from svgutils.compose import Unit
from svgutils.transform import GroupElement, SVGFigure
from svgutils.transform import GroupElement

if fg_svgs is None:
fg_svgs = []
@@ -350,16 +384,11 @@
roots = [f.getroot() for f in svgs]

# Query the size of each
sizes = []
for f in svgs:
viewbox = [float(v) for v in f.root.get("viewBox").split(" ")]
width = int(viewbox[2])
height = int(viewbox[3])
sizes.append((width, height))
sizes = np.array(
[[int(float(val)) for val in f.root.get("viewBox").split(" ")[2:4]] for f in svgs]
)
nsvgs = len(bg_svgs)

sizes = np.array(sizes)

# Calculate the scale to fit all widths
width = sizes[ref, 0]
scales = width / sizes[:, 0]
@@ -416,7 +445,7 @@
return svg


def transform_to_2d(data, max_axis):
def transform_to_2d(data: npt.NDArray, max_axis: int) -> npt.NDArray:
"""
Projects 3d data cube along one axis using maximum intensity with
preservation of the signs. Adapted from nilearn.
@@ -440,7 +469,7 @@
return np.rot90(maximum_intensity_data)

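The sign-preserving maximum-intensity projection can be sanity-checked in a few lines:

import numpy as np

from nireports.reportlets.utils import transform_to_2d

vol = np.zeros((4, 4, 4))
vol[1, 2, 3] = -5.0  # the strongest absolute voxel is negative
proj = transform_to_2d(vol, max_axis=2)
print(proj.shape, proj.min())  # (4, 4) -5.0: the sign survives the projection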

def get_parula():
def get_parula() -> mpl.colors.LinearSegmentedColormap:
"""Generate a 'parula' colormap."""
from matplotlib.colors import LinearSegmentedColormap

33 changes: 27 additions & 6 deletions nireports/tools/ndimage.py
@@ -27,21 +27,32 @@
# niworkflows/utils/images.py
"""Tooling to manipulate n-dimensional images."""

import os
import typing as ty

import nibabel as nb
import numpy as np
import numpy.typing as npt
from nibabel.spatialimages import SpatialImage

ImgT = ty.TypeVar("ImgT", bound=nb.filebasedimages.FileBasedImage)
SpatImgT = ty.TypeVar("SpatImgT", bound=SpatialImage)
Mat = npt.NDArray[np.float64]


def rotation2canonical(img):
def rotation2canonical(img: SpatialImage) -> Mat | None:
"""Calculate the rotation w.r.t. cardinal axes of input image."""
img = nb.as_closest_canonical(img)
# XXX: SpatialImage.affine needs to be typed
affine: Mat = img.affine
newaff = np.diag(img.header.get_zooms()[:3])
r = newaff @ np.linalg.pinv(img.affine[:3, :3])
r = newaff @ np.linalg.pinv(affine[:3, :3])
if np.allclose(r, np.eye(3)):
return None
return r


def rotate_affine(img, rot=None):
def rotate_affine(img: SpatImgT, rot: Mat | None = None) -> SpatImgT:
"""Rewrite the affine of a spatial image."""
if rot is None:
return img
@@ -52,11 +63,21 @@
return img.__class__(img.dataobj, affine, img.header)

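The two helpers above work as a pair: for an axis-aligned image rotation2canonical returns None and rotate_affine becomes a no-op. A minimal round trip:

import nibabel as nb
import numpy as np

from nireports.tools.ndimage import rotate_affine, rotation2canonical

img = nb.Nifti1Image(np.zeros((5, 5, 5)), np.eye(4))
rot = rotation2canonical(img)  # None: already aligned with the cardinal axes
print(rot, rotate_affine(img, rot) is img)  # None True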

def _get_values_inside_a_mask(main_file, mask_file):
main_nii = nb.load(main_file)
def load_api(path: str | os.PathLike[str], api: type[ImgT]) -> ImgT:
img = nb.load(path)
if not isinstance(img, api):
raise TypeError(f"File {path} does not implement {api} interface")

[Codecov warning: added line nireports/tools/ndimage.py#L69 not covered by tests]
return img

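load_api is the linchpin of this PR: nb.load is annotated as returning a bare FileBasedImage, so attribute access downstream does not type-check, whereas load_api narrows to the requested class both at runtime and for the checker. A quick demonstration:

import nibabel as nb
import numpy as np

from nireports.tools.ndimage import load_api

nb.Nifti1Image(np.zeros((2, 2, 2)), np.eye(4)).to_filename("toy.nii.gz")
img = load_api("toy.nii.gz", nb.Nifti1Image)  # inferred as Nifti1Image, not FileBasedImage
print(img.shape)  # (2, 2, 2)
# load_api("toy.nii.gz", nb.Cifti2Image) would raise TypeError instead.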

def _get_values_inside_a_mask(
main_file: str | os.PathLike[str],
mask_file: str | os.PathLike[str],
) -> npt.NDArray[np.float64]:
main_nii = load_api(main_file, SpatialImage)

[Codecov warning: added line nireports/tools/ndimage.py#L77 not covered by tests]
main_data = main_nii.get_fdata()
nan_mask = np.logical_not(np.isnan(main_data))
mask = nb.load(mask_file).get_fdata() > 0
mask = load_api(mask_file, SpatialImage).get_fdata() > 0

[Codecov warning: added line nireports/tools/ndimage.py#L80 not covered by tests]

data = main_data[np.logical_and(nan_mask, mask)]
return data
60 changes: 35 additions & 25 deletions nireports/tools/timeseries.py
@@ -27,11 +27,16 @@
# niworkflows/utils/timeseries.py
"""Extracting signals from NIfTI and CIFTI2 files."""

from collections.abc import Sequence

import nibabel as nb
import numpy as np
import numpy.typing as npt

from .ndimage import load_api


def get_tr(img):
def get_tr(img: nb.Nifti1Image | nb.Cifti2Image) -> float:
"""
Attempt to extract repetition time from NIfTI/CIFTI header.

@@ -49,17 +54,18 @@ def get_tr(img):
2.0

"""

try:
if isinstance(img, nb.Cifti2Image):
return img.header.matrix.get_index_map(0).series_step
except AttributeError:
else:
return img.header.get_zooms()[-1]
raise RuntimeError("Could not extract TR - unknown data structure type")

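The rewritten get_tr dispatches on isinstance instead of catching AttributeError; the NIfTI branch mirrors the doctest above:

import nibabel as nb
import numpy as np

from nireports.tools.timeseries import get_tr

img = nb.Nifti1Image(np.zeros((2, 2, 2, 10)), np.eye(4))
img.header.set_zooms((1.0, 1.0, 1.0, 2.5))  # the fourth zoom encodes the TR
print(get_tr(img))  # 2.5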

def cifti_timeseries(dataset):
def cifti_timeseries(
dataset: str | nb.Cifti2Image,
) -> tuple[npt.NDArray[np.float32], dict[str, list[int]]]:
"""Extract timeseries from CIFTI2 dataset."""
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
dataset = load_api(dataset, nb.Cifti2Image) if isinstance(dataset, str) else dataset

if dataset.nifti_header.get_intent()[0] != "ConnDenseSeries":
raise ValueError("Not a dense timeseries")
@@ -71,33 +77,37 @@ def cifti_timeseries(dataset):
"CIFTI_STRUCTURE_CEREBELLUM_LEFT": "CbL",
"CIFTI_STRUCTURE_CEREBELLUM_RIGHT": "CbR",
}
seg = {label: [] for label in list(labels.values()) + ["Other"]}
seg: dict[str, list[int]] = {label: [] for label in list(labels.values()) + ["Other"]}
for bm in matrix.get_index_map(1).brain_models:
label = labels.get(bm.brain_structure, "Other")
seg[label] += list(range(bm.index_offset, bm.index_offset + bm.index_count))

return dataset.get_fdata(dtype="float32").T, seg
return dataset.get_fdata(dtype=np.float32).T, seg

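Exercising cifti_timeseries requires a dense-timeseries CIFTI-2 file; the path below is a placeholder for, e.g., an fMRIPrep *_bold.dtseries.nii derivative:

from nireports.tools.timeseries import cifti_timeseries

# Placeholder path: any CIFTI-2 file with ConnDenseSeries intent would do.
data, segments = cifti_timeseries("sub-01_task-rest_bold.dtseries.nii")
print(data.shape)        # (grayordinates, timepoints), note the transpose
print(sorted(segments))  # ['CbL', 'CbR', 'CtxL', 'CtxR', 'Other']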

def nifti_timeseries(
dataset,
segmentation=None,
labels=("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"),
remap_rois=False,
lut=None,
):
dataset: str | nb.Nifti1Image,
segmentation: str | nb.Nifti1Image | None = None,
labels: Sequence[str] = ("Ctx GM", "dGM", "WM+CSF", "Cb", "Crown"),
remap_rois: bool = False,
lut: npt.NDArray[np.uint8] | None = None,
) -> tuple[npt.NDArray[np.float32], dict[str, list[int]] | None]:
"""Extract timeseries from NIfTI1/2 datasets."""
dataset = nb.load(dataset) if isinstance(dataset, str) else dataset
data = dataset.get_fdata(dtype="float32").reshape((-1, dataset.shape[-1]))
dataset = load_api(dataset, nb.Nifti1Image) if isinstance(dataset, str) else dataset
data: npt.NDArray[np.float32] = dataset.get_fdata(dtype="float32").reshape(
(-1, dataset.shape[-1])
)

if segmentation is None:
return data, None

# Open NIfTI and extract numpy array
segmentation = nb.load(segmentation) if isinstance(segmentation, str) else segmentation
segmentation = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1)
segmentation = (
load_api(segmentation, nb.Nifti1Image) if isinstance(segmentation, str) else segmentation
)
seg_data = np.asanyarray(segmentation.dataobj, dtype=int).reshape(-1)

remap_rois = remap_rois or (len(np.unique(segmentation[segmentation > 0])) > len(labels))
remap_rois = remap_rois or (len(np.unique(seg_data[seg_data > 0])) > len(labels))

# Map segmentation
if remap_rois or lut is not None:
@@ -108,12 +118,12 @@ def nifti_timeseries(
lut[1:11] = 3 # WM+CSF
lut[255] = 4 # Cerebellum
# Apply lookup table
segmentation = lut[segmentation]
seg_data = lut[seg_data]

fgmask = segmentation > 0
segmentation = segmentation[fgmask]
seg_dict = {}
for i in np.unique(segmentation):
seg_dict[labels[i - 1]] = np.argwhere(segmentation == i).squeeze()
fgmask = seg_data > 0
seg_values = seg_data[fgmask]
seg_dict: dict[str, list[int]] = {}
for i in np.unique(seg_values):
seg_dict[labels[i - 1]] = list(np.argwhere(seg_values == i).squeeze())

return data[fgmask], seg_dict
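Finally, the lookup-table remapping used by nifti_timeseries deserves a standalone illustration. The first two ranges below are assumed from the upstream niworkflows defaults; only lut[1:11] and lut[255] are visible in this hunk:

import numpy as np

lut = np.zeros((256,), dtype="uint8")
lut[100:201] = 1  # cortical gray matter (assumed default)
lut[30:99] = 2    # deep gray matter (assumed default)
lut[1:11] = 3     # WM+CSF
lut[255] = 4      # cerebellum
voxels = np.array([0, 5, 42, 150, 255])
print(lut[voxels])  # [0 3 2 1 4]: one vectorized gather remaps every voxel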