type: Annotate nireports.reportlets.utils
effigies committed Dec 19, 2024
1 parent 44d0ff6 commit 8c18471
Showing 1 changed file with 63 additions and 34 deletions.
97 changes: 63 additions & 34 deletions nireports/reportlets/utils.py
@@ -28,33 +28,53 @@
"""Helper tools for visualization purposes."""

import base64
import os
import re
import subprocess
import typing as ty
import warnings
from io import StringIO
from pathlib import Path
from shutil import which
from tempfile import TemporaryDirectory
from typing import Literal as L
from uuid import uuid4

import matplotlib as mpl
import nibabel as nb
import numpy as np
from nipype.utils import filemanip
import numpy.typing as npt
from svgutils.transform import SVGFigure

from ..tools.ndimage import load_api

SVGNS = "http://www.w3.org/2000/svg"

G = ty.TypeVar("G", bound=np.generic)


class DisplayObject(ty.Protocol):
frame_axes: mpl.axes.Axes
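A note on the new Protocol (illustration only, not part of this commit): DisplayObject is matched structurally, so any object exposing a frame_axes attribute of type mpl.axes.Axes satisfies it without inheriting from it, which is how nilearn display objects (e.g. OrthoSlicer) already behave. A minimal sketch with a hypothetical class:

    from matplotlib.axes import Axes

    class _FakeDisplay:
        """Hypothetical stand-in; exposing `frame_axes` is enough to match DisplayObject."""

        def __init__(self, axes: Axes) -> None:
            self.frame_axes = axes

    def _takes_display(display: DisplayObject) -> None:
        """A static type checker accepts _FakeDisplay instances for this parameter."""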


def robust_set_limits(data, plot_params, percentiles=(15, 99.8)):
def robust_set_limits(
data: npt.NDArray,
plot_params: dict[str, ty.Any],
percentiles: tuple[float, float] = (15, 99.8),
) -> dict[str, ty.Any]:
"""Set (vmax, vmin) based on percentiles of the data."""
plot_params["vmin"] = plot_params.get("vmin", np.percentile(data, percentiles[0]))
plot_params["vmax"] = plot_params.get("vmax", np.percentile(data, percentiles[1]))
return plot_params
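A hypothetical usage sketch (not part of the diff) against the newly annotated signature; vmin/vmax are derived from the data only when the caller has not already set them:

    import numpy as np

    data = np.random.default_rng(0).normal(size=(32, 32, 32))
    params = robust_set_limits(data, {"cmap": "gray"})
    # params now also carries "vmin" (15th percentile) and "vmax" (99.8th percentile);
    # an explicit caller-provided "vmax" would have been left untouched.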


def _get_limits(nifti_file, only_plot_noise=False):
def _get_limits(
nifti_file: str | npt.NDArray,
only_plot_noise: bool = False,
) -> tuple[float, float]:
if isinstance(nifti_file, str):
nii = nb.as_closest_canonical(nb.load(nifti_file))
data = nii.get_fdata()
nii = nb.as_closest_canonical(load_api(nifti_file, nb.Nifti1Image))
data: npt.NDArray = nii.get_fdata()

Codecov / codecov/patch warning: added lines #L76 - L77 were not covered by tests.
else:
data = nifti_file

@@ -71,7 +91,7 @@ def _get_limits(nifti_file, only_plot_noise=False):
return vmin, vmax


def svg_compress(image, compress="auto"):
def svg_compress(image: str, compress: bool | L["auto"] = "auto") -> str:
"""Generate a blob SVG from a matplotlib figure, may perform compression."""
# Check availability of svgo and cwebp
has_compress = all((which("svgo"), which("cwebp")))
@@ -145,20 +165,20 @@ def svg_compress(image, compress="auto"):
return "".join(image_svg) # straight up giant string


def svg2str(display_object, dpi=300):
def svg2str(display_object: DisplayObject, dpi: int = 300) -> str:
"""Serialize a nilearn display object to string."""
from io import StringIO

image_buf = StringIO()
display_object.frame_axes.figure.savefig(
image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k"
)
display_object.frame_axes.figure.clf()
figure = display_object.frame_axes.figure
assert isinstance(figure, mpl.figure.Figure)
figure.savefig(image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k")
figure.clf()
image_buf.seek(0)
return image_buf.getvalue()


def combine_svg(svg_list, axis="vertical"):
def combine_svg(svg_list: list[SVGFigure], axis="vertical") -> SVGFigure:
"""
Composes the input svgs into one standalone svg
"""
@@ -216,7 +236,11 @@ def combine_svg(svg_list, axis="vertical"):
return fig


def extract_svg(display_object, dpi=300, compress="auto"):
def extract_svg(
display_object: DisplayObject,
dpi: int = 300,
compress: bool | L["auto"] = "auto",
) -> str:
"""Remove the preamble of the svg files generated with nilearn."""
image_svg = svg2str(display_object, dpi)
if compress is True or compress == "auto":
@@ -238,14 +262,14 @@ def extract_svg(display_object, dpi=300, compress="auto"):
return image_svg[start_idx:end_idx]


def _bbox(img_data, bbox_data):
def _bbox(img_data: npt.NDArray[G], bbox_data: npt.NDArray) -> npt.NDArray[G]:
"""Calculate the bounding box of a binary segmentation."""
B = np.argwhere(bbox_data)
(ystart, xstart, zstart), (ystop, xstop, zstop) = B.min(0), B.max(0) + 1
return img_data[ystart:ystop, xstart:xstop, zstart:zstop]
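A small worked example (hypothetical arrays, not part of the diff) of what _bbox returns, cropping an image array to the extent of a binary mask:

    import numpy as np

    img = np.random.default_rng(0).random((64, 64, 64))
    mask = np.zeros((64, 64, 64), dtype=bool)
    mask[10:20, 15:30, 5:25] = True
    cropped = _bbox(img, mask)
    # cropped.shape == (10, 15, 20): the tightest box enclosing the nonzero mask voxels.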


def cuts_from_bbox(mask_nii, cuts=3):
def cuts_from_bbox(mask_nii: nb.Nifti1Image, cuts: int = 3) -> dict[str, list[float]]:
"""Find equi-spaced cuts for presenting images."""
mask_data = np.asanyarray(mask_nii.dataobj) > 0.0

@@ -291,7 +315,9 @@ def cuts_from_bbox(mask_nii, cuts=3):
return {k: list(v) for k, v in zip(["x", "y", "z"], np.around(ras_coords, 3))}


def _3d_in_file(in_file):
def _3d_in_file(
in_file: nb.Nifti1Image | str | os.PathLike | list[str | os.PathLike],
) -> nb.Nifti1Image:
"""if self.inputs.in_file is 3d, return it.
if 4d, pick an arbitrary volume and return that.
@@ -300,20 +326,24 @@ def _3d_in_file(in_file):
"""
from nilearn import image as nlimage

in_file = filemanip.filename_to_list(in_file)[0]
if isinstance(in_file, list):
in_file = in_file[0]

Codecov / codecov/patch warning: added line #L330 was not covered by tests.

try:
in_file = nb.load(in_file)
except AttributeError:
in_file = in_file
if not isinstance(in_file, nb.Nifti1Image):
in_file = load_api(in_file, nb.Nifti1Image)

if len(in_file.shape) == 3:
return in_file

return nlimage.index_img(in_file, 0)


def compose_view(bg_svgs, fg_svgs, ref=0, out_file="report.svg"):
def compose_view(
bg_svgs: list[SVGFigure],
fg_svgs: list[SVGFigure],
ref: int = 0,
out_file: str | os.PathLike[str] = "report.svg",
) -> str:
"""
Compose svgs into one standalone svg with CSS flickering animation.
@@ -338,9 +368,13 @@ def compose_view(bg_svgs, fg_svgs, ref=0, out_file="report.svg"):
return str(out_file)


def _compose_view(bg_svgs, fg_svgs, ref=0):
def _compose_view(
bg_svgs: list[SVGFigure],
fg_svgs: list[SVGFigure],
ref: int = 0,
) -> list[str]:
from svgutils.compose import Unit
from svgutils.transform import GroupElement, SVGFigure
from svgutils.transform import GroupElement

if fg_svgs is None:
fg_svgs = []
@@ -350,16 +384,11 @@ def _compose_view(bg_svgs, fg_svgs, ref=0):
roots = [f.getroot() for f in svgs]

# Query the size of each
sizes = []
for f in svgs:
viewbox = [float(v) for v in f.root.get("viewBox").split(" ")]
width = int(viewbox[2])
height = int(viewbox[3])
sizes.append((width, height))
sizes = np.array(
[[int(float(val)) for val in f.root.get("viewBox").split(" ")[2:4]] for f in svgs]
)
nsvgs = len(bg_svgs)

sizes = np.array(sizes)

# Calculate the scale to fit all widths
width = sizes[ref, 0]
scales = width / sizes[:, 0]
@@ -416,7 +445,7 @@ def _compose_view(bg_svgs, fg_svgs, ref=0):
return svg


def transform_to_2d(data, max_axis):
def transform_to_2d(data: npt.NDArray, max_axis: int) -> npt.NDArray:
"""
Projects 3d data cube along one axis using maximum intensity with
preservation of the signs. Adapted from nilearn.
@@ -440,7 +469,7 @@ def transform_to_2d(data, max_axis):
return np.rot90(maximum_intensity_data)
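The function body is collapsed in this view; as a rough sketch of the idea (an assumption adapted from nilearn's approach, not code from this file), the projection keeps, along max_axis, the voxel with the largest absolute value together with its sign, then rotates the plane for display:

    import numpy as np

    def _signed_mip(data: np.ndarray, max_axis: int) -> np.ndarray:
        # Index of the largest-magnitude voxel along the projection axis.
        idx = np.abs(data).argmax(axis=max_axis)
        # Gather the signed values at those indices so negative contrast survives.
        proj = np.take_along_axis(data, np.expand_dims(idx, axis=max_axis), axis=max_axis)
        return np.rot90(np.squeeze(proj, axis=max_axis))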


def get_parula():
def get_parula() -> mpl.colors.LinearSegmentedColormap:
"""Generate a 'parula' colormap."""
from matplotlib.colors import LinearSegmentedColormap

