From a88f6eb5d84a574d91d4b61728d58f6d567e1211 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Mon, 6 May 2024 18:55:13 -0400 Subject: [PATCH] ENH: Add DWI volume plot method Add DWI volume plot method. --- nireports/reportlets/modality/dwi.py | 34 ++++++++++++++++++++++++++++ nireports/tests/test_dwi.py | 29 +++++++++++++++++++++++- 2 files changed, 62 insertions(+), 1 deletion(-) diff --git a/nireports/reportlets/modality/dwi.py b/nireports/reportlets/modality/dwi.py index d28727d4..94cb5729 100644 --- a/nireports/reportlets/modality/dwi.py +++ b/nireports/reportlets/modality/dwi.py @@ -21,10 +21,44 @@ # https://www.nipreps.org/community/licensing/ # """Visualizations for diffusion MRI data.""" +import nibabel as nb import numpy as np from matplotlib import pyplot as plt from matplotlib.pyplot import cm from mpl_toolkits.mplot3d import art3d +from nilearn.plotting import plot_anat + + +def plot_dwi(dataobj, affine, gradient=None, **kwargs): + """Plot a DW map.""" + + plt.rcParams.update( + { + "text.usetex": True, + "font.family": "sans-serif", + "font.sans-serif": ["Helvetica"], + } + ) + + affine = np.diag(nb.affines.voxel_sizes(affine).tolist() + [1]) + affine[:3, 3] = -1.0 * (affine[:3, :3] @ ((np.array(dataobj.shape) - 1) * 0.5)) + + vmax = kwargs.pop("vmax", None) or np.percentile(dataobj, 98) + cut_coords = kwargs.pop("cut_coords", None) or (0, 0, 0) + + return plot_anat( + nb.Nifti1Image(dataobj, affine, None), + vmax=vmax, + cut_coords=cut_coords, + title=( + r"Reference $b$=0" + if gradient is None + else f"""\ +$b$={gradient[3].astype(int)}, \ +$\\vec{{b}}$ = ({', '.join(str(v) for v in gradient[:3])})""" + ), + **kwargs, + ) def plot_heatmap( diff --git a/nireports/tests/test_dwi.py b/nireports/tests/test_dwi.py index 830e3bc0..84257532 100644 --- a/nireports/tests/test_dwi.py +++ b/nireports/tests/test_dwi.py @@ -23,11 +23,38 @@ """Test DWI reportlets.""" import pytest +from pathlib import Path +import nibabel as nb import numpy as np from matplotlib import pyplot as plt -from nireports.reportlets.modality.dwi import plot_gradients +from nireports.reportlets.modality.dwi import plot_dwi, plot_gradients + + +@pytest.mark.parametrize( + 'dwi', 'dwi_btable', + ['ds000114_sub-01_ses-test_dwi.nii.gz', 'ds000114_singleshell'], +) +def test_plot_dwi(tmp_path, testdata_path, dwi, dwi_btable, outdir): + """Check the plot of DWI data.""" + + dwi_img = nb.load(testdata_path / f'{dwi}') + affine = dwi_img.affine + + bvecs = np.loadtxt(testdata_path / f'{dwi_btable}.bvec').T + bvals = np.loadtxt(testdata_path / f'{dwi_btable}.bval') + + gradients = np.hstack([bvecs, bvals[:, None]]) + + # Pick a random volume to show + rng = np.random.default_rng(1234) + idx = rng.integers(low=0, high=len(bvals), size=1).item() + + _ = plot_dwi(dwi_img.get_fdata()[..., idx], affine, gradient=gradients[idx]) + + if outdir is not None: + plt.savefig(outdir / f'{Path(dwi).with_suffix("").stem}.svg', bbox_inches='tight') @pytest.mark.parametrize(