Skip to content

Commit

Permalink
enh: flexibilize views of plot_mosaic to render nonhuman imaging
Browse files Browse the repository at this point in the history
Makes ``plot_mosaic`` more configurable to decide what view(s) are to be
plotted, and which one will be the main view.

Related: nipreps/mriqc#1027.
Resolves: #1.

Co-authored-by: Eilidh MacNicol <[email protected]>
  • Loading branch information
oesteban and eilidhmacnicol committed Mar 9, 2023
1 parent fcd724c commit a8a99db
Show file tree
Hide file tree
Showing 3 changed files with 145 additions and 25 deletions.
28 changes: 28 additions & 0 deletions nireports/interfaces/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,29 @@ def _run_interface(self, runtime):
class _PlotMosaicInputSpec(_PlotBaseInputSpec):
bbox_mask_file = File(exists=True, desc="brain mask")
only_noise = traits.Bool(False, desc="plot only noise")
main_view = traits.Enum(
"axial",
"sagittal",
"coronal",
default="axial",
usedefault=True,
)
addon_view1 = traits.Enum(
"sagittal",
"axial",
"coronal",
None,
default="sagittal",
usedefault=True,
)
addon_view2 = traits.Enum(
None,
"axial",
"sagittal",
"coronal",
default=None,
usedefault=True,
)


class _PlotMosaicOutputSpec(TraitedSpec):
Expand Down Expand Up @@ -144,6 +167,11 @@ def _run_interface(self, runtime):
bbox_mask_file=mask,
cmap=self.inputs.cmap,
annotate=self.inputs.annotate,
views=(
self.inputs.main_view,
self.inputs.addon_view1,
self.inputs.addon_view2,
)
)
self._results["out_file"] = str((Path(runtime.cwd) / self.inputs.out_file).resolve())
return runtime
Expand Down
105 changes: 80 additions & 25 deletions nireports/reportlets/mosaic.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,9 +495,18 @@ def plot_mosaic(
plot_sagittal=True,
fig=None,
zmax=128,
views=("axial", "sagittal", None),
):
"""Plot a mosaic of 2D cuts."""

if isinstance(img, (str, bytes)):
VIEW_AXES_ORDER = (2, 1, 0)

# Error with inconsistent views input
print(views)
if views[0] is None or ((views[1] is None) and (views[2] is not None)):
raise RuntimeError("First view must not be None")

if not hasattr(img, "shape"):
nii = nb.as_closest_canonical(nb.load(img))
img_data = nii.get_fdata()
zooms = nii.header.get_zooms()
Expand All @@ -506,20 +515,43 @@ def plot_mosaic(
zooms = [1.0, 1.0, 1.0]
out_file = "mosaic.svg"

if views[1] is None and plot_sagittal:
views = (views[0], "sagittal", None)

# Select the axis through which we cut the planes
axes_order = [
["sagittal", "coronal", "axial"].index(views[0]),
["sagittal", "coronal", "axial"].index(views[1] or "sagittal"),
]

# If 3D, complete last axis
if img_data.ndim > 3:
raise RuntimeError("Dataset has more than three dimensions")
elif img_data.ndim == 3:
axes_order += list(set(range(3)) - set(axes_order))

# Remove extra dimensions
img_data = np.squeeze(img_data)
img_data = np.moveaxis(
np.squeeze(img_data),
axes_order,
VIEW_AXES_ORDER[:len(axes_order)],
)

if img_data.shape[2] > zmax and bbox_mask_file is None:
# Create mask for bounding box
if bbox_mask_file is not None:
bbox_data = np.moveaxis(
nb.as_closest_canonical(nb.load(bbox_mask_file)).get_fdata(),
axes_order,
VIEW_AXES_ORDER[:len(axes_order)],
)
img_data = _bbox(img_data, bbox_data)
elif img_data.shape[-1] > zmax:
lowthres = np.percentile(img_data, 5)
mask_file = np.ones_like(img_data)
mask_file[img_data <= lowthres] = 0
img_data = _bbox(img_data, mask_file)

if bbox_mask_file is not None:
bbox_data = nb.as_closest_canonical(nb.load(bbox_mask_file)).get_fdata()
img_data = _bbox(img_data, bbox_data)

z_vals = np.array(list(range(0, img_data.shape[2])))
z_vals = np.arange(0, img_data.shape[-1], dtype=int)

# Reduce the number of slices shown
if len(z_vals) > zmax:
Expand All @@ -539,12 +571,15 @@ def plot_mosaic(
z_vals = z_vals[::2]

n_images = len(z_vals)
nrows = math.ceil(n_images / ncols)
if plot_sagittal:
nrows += 1
extra_rows = sum(bool(v) for v in views[1:])
nrows = math.ceil(n_images / ncols) + extra_rows

if overlay_mask:
overlay_data = nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata()
overlay_data = np.moveaxis(
nb.as_closest_canonical(nb.load(overlay_mask)).get_fdata(),
axes_order,
VIEW_AXES_ORDER[:len(axes_order)],
)

# create figures
if fig is None:
Expand All @@ -556,20 +591,22 @@ def plot_mosaic(
if not vmax:
vmax = est_vmax

slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[0]]
naxis = 1
for z_val in z_vals:
ax = fig.add_subplot(nrows, ncols, naxis)

if overlay_mask:
ax.set_rasterized(True)

plot_slice(
img_data[:, :, z_val],
vmin=vmin,
vmax=vmax,
cmap=cmap,
ax=ax,
spacing=zooms[:2],
label="%d" % z_val,
spacing=slice_spacing,
label=f"{z_val:d}",
annotate=annotate,
)

Expand All @@ -586,31 +623,49 @@ def plot_mosaic(
vmax=1,
cmap=msk_cmap,
ax=ax,
spacing=zooms[:2],
spacing=slice_spacing,
)
naxis += 1

if plot_sagittal:
naxis = ncols * (nrows - 1) + 1
if views[1] is not None:
slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[1]]
naxis = ncols * (nrows - extra_rows) + 1
step = max(int(img_data.shape[-2] / (ncols + 1)), 1)
start = step
stop = img_data.shape[-2] - step

for slice_val in list(range(start, stop, step))[:ncols]:
ax = fig.add_subplot(nrows, ncols, naxis)

step = int(img_data.shape[0] / (ncols + 1))
plot_slice(
img_data[:, slice_val, :],
vmin=vmin,
vmax=vmax,
cmap=cmap,
ax=ax,
label=f"{slice_val:d}",
spacing=slice_spacing,
)
naxis += 1

if views[1] is not None and views[2] is not None:
slice_spacing = [vs for i, vs in enumerate(zooms) if i != axes_order[2]]
naxis = ncols * (nrows - extra_rows) + 1
step = max(int(img_data.shape[0] / (ncols + 1)), 1)
start = step
stop = img_data.shape[0] - step

if step == 0:
step = 1

for x_val in list(range(start, stop, step))[:ncols]:
for slice_val in list(range(start, stop, step))[:ncols]:
ax = fig.add_subplot(nrows, ncols, naxis)

plot_slice(
img_data[x_val, ...],
img_data[slice_val, ...],
vmin=vmin,
vmax=vmax,
cmap=cmap,
ax=ax,
label="%d" % x_val,
spacing=[zooms[0], zooms[2]],
label=f"{slice_val:d}",
spacing=slice_spacing,
)
naxis += 1

Expand Down
37 changes: 37 additions & 0 deletions nireports/tests/test_reportlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@
"""Test reportlets module."""
import os
from pathlib import Path
from itertools import permutations
from functools import partial

import nibabel as nb
import numpy as np
Expand All @@ -32,6 +34,7 @@
from nireports.reportlets.modality.func import fMRIPlot
from nireports.reportlets.nuisance import plot_carpet
from nireports.reportlets.surface import cifti_surfaces_plot
from nireports.reportlets.mosaic import plot_mosaic
from nireports.reportlets.xca import compcor_variance_plot, plot_melodic_components
from nireports.tools.timeseries import cifti_timeseries as _cifti_timeseries
from nireports.tools.timeseries import get_tr as _get_tr
Expand Down Expand Up @@ -321,3 +324,37 @@ def test_nifti_carpetplot(tmp_path, testdata_path, outdir):
output_file=outdir / "carpetplot_nifti.svg" if outdir is not None else None,
drop_trs=0,
)


_views = (
list(permutations(("axial", "sagittal", "coronal", None), 3))
+ [(v, None, None) for v in ("axial", "sagittal", "coronal")]
)


@pytest.mark.parametrize("views", _views)
@pytest.mark.parametrize("plot_sagittal", (True, False))
@pytest.mark.parametrize("only_plot_noise", (True, False))
def test_mriqc_plot_mosaic(tmp_path, testdata_path, outdir, views, plot_sagittal, only_plot_noise):
"""Exercise the generation of mosaics."""

out_file = (
outdir / f"mosaic_{'_'.join(views)}_{plot_sagittal:d}_{only_plot_noise:d}.svg"
) if outdir is not None else None

testfunc = partial(
plot_mosaic,
testdata_path / "testSpatialNormalizationRPTMovingWarpedImage.nii.gz",
views=views,
out_file=out_file,
title=(
f"A mosaic plotting example: views={views}, plot_sagittal={plot_sagittal}",
f"only_plot_noise={only_plot_noise}"
),
)

if views[0] is None or ((views[1] is None) and (views[2] is not None)):
with pytest.raises(RuntimeError):
testfunc()
else:
testfunc()

0 comments on commit a8a99db

Please sign in to comment.