Skip to content

Commit

Permalink
ENH: Add DWI volume plot method
Browse files Browse the repository at this point in the history
Add DWI volume plot method.
  • Loading branch information
jhlegarreta committed May 6, 2024
1 parent 0924b31 commit a88f6eb
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 1 deletion.
34 changes: 34 additions & 0 deletions nireports/reportlets/modality/dwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,44 @@
# https://www.nipreps.org/community/licensing/
#
"""Visualizations for diffusion MRI data."""
import nibabel as nb
import numpy as np
from matplotlib import pyplot as plt
from matplotlib.pyplot import cm
from mpl_toolkits.mplot3d import art3d
from nilearn.plotting import plot_anat


def plot_dwi(dataobj, affine, gradient=None, **kwargs):
"""Plot a DW map."""

plt.rcParams.update(
{
"text.usetex": True,
"font.family": "sans-serif",
"font.sans-serif": ["Helvetica"],
}
)

affine = np.diag(nb.affines.voxel_sizes(affine).tolist() + [1])
affine[:3, 3] = -1.0 * (affine[:3, :3] @ ((np.array(dataobj.shape) - 1) * 0.5))

vmax = kwargs.pop("vmax", None) or np.percentile(dataobj, 98)
cut_coords = kwargs.pop("cut_coords", None) or (0, 0, 0)

return plot_anat(
nb.Nifti1Image(dataobj, affine, None),
vmax=vmax,
cut_coords=cut_coords,
title=(
r"Reference $b$=0"
if gradient is None
else f"""\
$b$={gradient[3].astype(int)}, \
$\\vec{{b}}$ = ({', '.join(str(v) for v in gradient[:3])})"""
),
**kwargs,
)


def plot_heatmap(
Expand Down
29 changes: 28 additions & 1 deletion nireports/tests/test_dwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,38 @@
"""Test DWI reportlets."""

import pytest
from pathlib import Path

import nibabel as nb
import numpy as np
from matplotlib import pyplot as plt

from nireports.reportlets.modality.dwi import plot_gradients
from nireports.reportlets.modality.dwi import plot_dwi, plot_gradients


@pytest.mark.parametrize(
'dwi', 'dwi_btable',
['ds000114_sub-01_ses-test_dwi.nii.gz', 'ds000114_singleshell'],
)
def test_plot_dwi(tmp_path, testdata_path, dwi, dwi_btable, outdir):
"""Check the plot of DWI data."""

dwi_img = nb.load(testdata_path / f'{dwi}')
affine = dwi_img.affine

bvecs = np.loadtxt(testdata_path / f'{dwi_btable}.bvec').T
bvals = np.loadtxt(testdata_path / f'{dwi_btable}.bval')

gradients = np.hstack([bvecs, bvals[:, None]])

# Pick a random volume to show
rng = np.random.default_rng(1234)
idx = rng.integers(low=0, high=len(bvals), size=1).item()

_ = plot_dwi(dwi_img.get_fdata()[..., idx], affine, gradient=gradients[idx])

if outdir is not None:
plt.savefig(outdir / f'{Path(dwi).with_suffix("").stem}.svg', bbox_inches='tight')


@pytest.mark.parametrize(
Expand Down

0 comments on commit a88f6eb

Please sign in to comment.