From f4709305cc28448a7fce0ae5c5602011d140bf8b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jon=20Haitz=20Legarreta=20Gorro=C3=B1o?= Date: Wed, 8 May 2024 20:18:13 -0400 Subject: [PATCH] ENH: Add raincloud plot capabilities Add raincloud plot capabilities and demonstrate on the FA values of computed on a DWI volume across different tissue types. Add the corresponding tests. Add a `utils.py` helper module to the `tests` module. Contains a function to create random data used by the added tests. --- nireports/interfaces/__init__.py | 7 +- nireports/interfaces/nuisance.py | 108 ++++++++++- nireports/reportlets/modality/dwi.py | 28 +++ nireports/reportlets/nuisance.py | 271 ++++++++++++++++++++++++++- nireports/tests/test_dwi.py | 42 ++++- nireports/tests/test_interfaces.py | 42 ++++- nireports/tests/test_reportlets.py | 105 ++++++++++- nireports/tests/utils.py | 37 ++++ 8 files changed, 631 insertions(+), 9 deletions(-) create mode 100644 nireports/tests/utils.py diff --git a/nireports/interfaces/__init__.py b/nireports/interfaces/__init__.py index e190ef32..7445a1ac 100644 --- a/nireports/interfaces/__init__.py +++ b/nireports/interfaces/__init__.py @@ -24,11 +24,16 @@ from nireports.interfaces.fmri import FMRISummary from nireports.interfaces.mosaic import PlotContours, PlotMosaic, PlotSpikes -from nireports.interfaces.nuisance import CompCorVariancePlot, ConfoundsCorrelationPlot +from nireports.interfaces.nuisance import ( + CompCorVariancePlot, + ConfoundsCorrelationPlot, + RaincloudPlot, +) __all__ = ( "CompCorVariancePlot", "ConfoundsCorrelationPlot", + "RaincloudPlot", "FMRISummary", "PlotContours", "PlotMosaic", diff --git a/nireports/interfaces/nuisance.py b/nireports/interfaces/nuisance.py index 734767f4..18f374ad 100644 --- a/nireports/interfaces/nuisance.py +++ b/nireports/interfaces/nuisance.py @@ -32,7 +32,7 @@ ) from nipype.utils.filemanip import fname_presuffix -from nireports.reportlets.nuisance import confounds_correlation_plot +from nireports.reportlets.nuisance import confounds_correlation_plot, plot_raincloud from nireports.reportlets.xca import compcor_variance_plot @@ -142,3 +142,109 @@ def _run_interface(self, runtime): ignore_initial_volumes=self.inputs.ignore_initial_volumes, ) return runtime + + +class _RaincloudPlotInputSpec(BaseInterfaceInputSpec): + data_file = File(exists=True, mandatory=True, desc="File containing the data") + out_file = traits.Either(None, File, value=None, usedefault=True, desc="Path to save plot") + group_name = traits.Str( + "group_name", + mandatory=True, + desc="Group name of interest", + ) + feature = traits.Str( + "feature", + mandatory=True, + desc="Feature of interest", + ) + palette = traits.Str( + "Set2", + usedefault=True, + desc="Color palette name", + ) + orient = traits.Str( + "v", + usedefault=True, + desc="Orientation", + ) + density = traits.Bool( + True, + usedefault=True, + desc="``True`` to plot the density", + ) + upper_limit_value = traits.Float( + None, + usedefault=True, + desc="Upper limit value over which any value in the data will be styled " + "with a different style", + ) + upper_limit_color = traits.Str( + "gray", + usedefault=True, + desc="Lower limit value under which any value in the data will be styled " + "with a different style", + ) + lower_limit_value = traits.Float( + None, + usedefault=True, + desc="", + ) + lower_limit_color = traits.Str( + "gray", + usedefault=True, + desc="Color name to represent values under ``lower_limit_value``", + ) + limit_offset = traits.Float( + None, + usedefault=True, + desc="Offset to plot the values over/under the upper/lower limit values", + ) + mark_nans = traits.Bool( + True, + usedefault=True, + desc="``True`` to plot NaNs as dots. ``nans_values`` must be provided if True", + ) + nans_value = traits.Float( + None, + usedefault=True, + desc="Value to use for NaN values`", + ) + nans_color = traits.Str( + "black", + usedefault=True, + desc="Color name to represent NaN values", + ) + + +class _RaincloudPlotOutputSpec(TraitedSpec): + out_file = File(exists=True, desc="Path to saved plot") + + +class RaincloudPlot(SimpleInterface): + """Plot a raincloud of values.""" + + input_spec = _RaincloudPlotInputSpec + output_spec = _RaincloudPlotOutputSpec + + def _run_interface(self, runtime, **kwargs): + if self.inputs.out_file is None: + self._results["out_file"] = fname_presuffix( + self.inputs.data_file, + suffix="_raincloud.svg", + use_ext=False, + newpath=runtime.cwd, + ) + else: + self._results["out_file"] = self.inputs.out_file + plot_raincloud( + data_file=self.inputs.data_file, + group_name=self.inputs.group_name, + feature=self.inputs.feature, + palette=self.inputs.palette, + orient=self.inputs.orient, + density=self.inputs.density, + mark_nans=self.inputs.mark_nans, + output_file=self._results["out_file"], + **kwargs, + ) + return runtime diff --git a/nireports/reportlets/modality/dwi.py b/nireports/reportlets/modality/dwi.py index 117b16f5..a42ce57b 100644 --- a/nireports/reportlets/modality/dwi.py +++ b/nireports/reportlets/modality/dwi.py @@ -29,6 +29,8 @@ from mpl_toolkits.mplot3d import art3d from nilearn.plotting import plot_anat +from nireports.reportlets.nuisance import plot_raincloud + def plot_dwi(dataobj, affine, gradient=None, **kwargs): """ @@ -405,3 +407,29 @@ def plot_gradients( plt.suptitle(title) return ax + + +def plot_tissue_values(data_file, group_name, feature, **kwargs): + """Generate a raincloud plot with the data points corresponding to the + ``feature`` value contained in the data file. + + Parameters + ---------- + data_file : :obj:`str` + File containing the data of interest. + group_name : :obj:`str` + The group name of interest to be plot. + feature : :obj:`str` + The feature of interest to be plot. + kwargs : :obj:`dict` + Extra args given to :func:~`nireports.reportlets.nuisance.plot_raincloud`. + + Returns + ------- + axes and gridspec + Plotting axes and gridspec. Returned only if ``output_file`` is ``None``. + output_file : :obj:`str` + The file where the figure is saved. + """ + + return plot_raincloud(data_file, group_name, feature, **kwargs) diff --git a/nireports/reportlets/nuisance.py b/nireports/reportlets/nuisance.py index 444104ec..c4e4e7c5 100644 --- a/nireports/reportlets/nuisance.py +++ b/nireports/reportlets/nuisance.py @@ -25,11 +25,13 @@ """Plotting distributions.""" import math +import operator import os.path as op import matplotlib as mpl import matplotlib.pyplot as plt import numpy as np +import pandas as pd import seaborn as sns from matplotlib.backends.backend_pdf import FigureCanvasPdf as FigureCanvas from matplotlib.colorbar import ColorbarBase @@ -626,8 +628,6 @@ def confoundplot( cutoff=None, ylims=None, ): - import seaborn as sns - # Define TR and number of frames notr = False if tr is None: @@ -856,8 +856,6 @@ def confounds_correlation_plot( output_file: :obj:`str` The file where the figure is saved. """ - import pandas as pd - import seaborn as sns confounds_data = pd.read_table(confounds_file) @@ -932,3 +930,268 @@ def confounds_correlation_plot( figure = None return output_file return [ax0, ax1], gs + + +def _plot_density(x, y, df, group_name, palette, orient): + ax = sns.violinplot( + x=x, + y=y, + data=df, + hue=group_name, + dodge=False, + palette=palette, + density_norm="width", + inner=None, + orient=orient, + ) + + # Cut half of the violins + for violin in ax.collections: + bbox = violin.get_paths()[0].get_extents() + x0, y0, width, height = bbox.bounds + width_denom = 2 + height_denom = 1 + if orient == "h": + width_denom = 1 + height_denom = 2 + violin.set_clip_path( + plt.Rectangle( + (x0, y0), width / width_denom, height / height_denom, transform=ax.transData + ) + ) + + return ax + + +def _jitter_data_points(old_len_collections, orient, width, ax): + offset = np.array([width, 0]) + if orient == "h": + offset = np.array([0, width]) + for dots in ax.collections[old_len_collections:]: + dots.set_offsets(dots.get_offsets() + offset) + + +def _plot_nans(df, x, y, color, orient, ax): + df_nans = df[df.isna().any(axis=1)] + sns.stripplot( + x=x, + y=y, + data=df_nans, + color=color, + orient=orient, + ax=ax, + ) + + +def _plot_out_of_range( + df, + x, + feature, + orient, + limit_offset, + limit_value, + limit_color, + limit_name, + color_vble_name, + _op, + ax, +): + if limit_color is None: + raise ValueError( + f"``{color_vble_name}`` must be provided if ``{limit_name}`` is provided." + ) + if limit_offset is None: + raise ValueError(f"``limit_offset`` must be provided if ``{limit_name}`` is provided.") + if _op == operator.gt: + arithm = operator.add + elif _op == operator.lt: + arithm = operator.sub + else: + raise ValueError(f"``{_op}`` must be either ``gt`` or ``lt``.") + + df_overflow = df[_op(df[feature], limit_value)] + sns.stripplot( + x=x, + y=arithm(limit_value, limit_offset), + data=df_overflow, + color=limit_color, + orient=orient, + ax=ax, + ) + + +def plot_raincloud( + data_file, + group_name, + feature, + palette="Set2", + orient="v", + density=True, + upper_limit_value=None, + upper_limit_color="gray", + lower_limit_value=None, + lower_limit_color="gray", + limit_offset=None, + mark_nans=True, + nans_value=None, + nans_color="black", + figure=None, + output_file=None, +): + """ + Generate a raincloud plot with the data points corresponding to the + ``feature`` value contained in the data file. + + Parameters + ---------- + data_file : :obj:`str` + File containing the data of interest. + figure : :obj:`matplotlib.pyplot.figure` or None + Existing figure on which to plot. + group_name : :obj:`str` + The group name of interest to be plot. + feature : :obj:`str` + The feature of interest to be plot. + palette : :obj:`str`, optional + Color palette name provided to :func:`sns.stripplot`. + orient : :obj:`str`, optional + Plot orientation (``v`` or ``h``). + density : :obj:`bool`, optional + ``True`` to plot the density of the data points. + upper_limit_value : :obj:`float`, optional + Upper limit value over which any value in the data will be styled with a + different style. + upper_limit_color : :obj:`str`, optional + Color name to represent values over ``upper_limit_value``. + lower_limit_value : :obj:`float`, optional + Lower limit value under which any value in the data will be styled with + a different style. + lower_limit_color : :obj:`str`, optional + Color name to represent values under ``lower_limit_value``. + limit_offset : :obj:`float`, optional + Offset to plot the values over/under the upper/lower limit values. + mark_nans : :obj:`bool`, optional + ``True`` to plot NaNs as dots. ``nans_values`` must be provided if True. + nans_value : :obj:`float`, optional + Value to use for NaN values. + nans_color : :obj:`str`, optional + Color name to represent NaN values. + output_file : :obj:`str` or :obj:`None` + Path where the output figure should be saved. If this is not defined, + then the plotting axes will be returned instead of the saved figure + path. + + Returns + ------- + axes and gridspec + Plotting axes and gridspec. Returned only if ``output_file`` is ``None``. + output_file : :obj:`str` + The file where the figure is saved. + """ + + df = pd.read_csv(data_file, sep=r"[\t\s]+", engine="python") + + if figure is None: + plt.figure(figsize=(7, 5)) + + gs = GridSpec(1, 1) + ax = plt.subplot(gs[0, 0]) + + sns.set(style="white", font_scale=2) + + x = feature + y = group_name + # Swap x/y if the requested orientation is vertical + if orient == "v": + x = group_name + y = feature + + # Plot the density + if density: + ax = _plot_density(x, y, df, group_name, palette, orient) + + # Add boxplots + width = 0.15 + sns.boxplot( + x=x, + y=y, + data=df, + color="black", + width=width, + zorder=10, + showcaps=True, + boxprops={"facecolor": "none", "zorder": 10}, + showfliers=True, + whiskerprops={"linewidth": 2, "zorder": 10}, + saturation=1, + orient=orient, + ax=ax, + ) + + old_len_collections = len(ax.collections) + + # Plot the data points as dots + sns.stripplot( + x=x, + y=y, + hue=group_name, + data=df, + palette=palette, + edgecolor="white", + size=3, + jitter=0.1, + zorder=0, + orient=orient, + ax=ax, + ) + + # Offset the dots that would be otherwise shadowed by the violins + if density: + _jitter_data_points(old_len_collections, orient, width, ax) + + # Draw nans if any + if mark_nans: + if nans_value is None: + raise ValueError("``nans_value`` must be provided if ``mark_nans`` is True.") + _plot_nans(df, x, nans_value, nans_color, orient, ax) + + # If upper/lower limits are provided, draw the points with a different color + if upper_limit_value is not None: + _plot_out_of_range( + df, + x, + feature, + orient, + limit_offset, + upper_limit_value, + upper_limit_color, + "upper_limit_value", + "upper_limit_color", + operator.gt, + ax, + ) + + if lower_limit_value is not None: + _plot_out_of_range( + df, + x, + feature, + orient, + limit_offset, + lower_limit_value, + lower_limit_color, + "lower_limit_value", + "lower_limit_color", + operator.lt, + ax, + ) + + if output_file is not None: + figure = plt.gcf() + plt.tight_layout() + figure.savefig(output_file, bbox_inches="tight") + plt.close(figure) + figure = None + return output_file + + return ax, gs diff --git a/nireports/tests/test_dwi.py b/nireports/tests/test_dwi.py index 6e43fea4..32a32e08 100644 --- a/nireports/tests/test_dwi.py +++ b/nireports/tests/test_dwi.py @@ -27,7 +27,8 @@ import pytest from matplotlib import pyplot as plt -from nireports.reportlets.modality.dwi import plot_dwi, plot_gradients +from nireports.reportlets.modality.dwi import plot_dwi, plot_gradients, plot_tissue_values +from nireports.tests.utils import _generate_raincloud_random_data def test_plot_dwi(tmp_path, testdata_path, outdir): @@ -69,3 +70,42 @@ def test_plot_gradients(tmp_path, testdata_path, dwi_btable, outdir): if outdir is not None: plt.savefig(outdir / f"{dwi_btable}.svg", bbox_inches="tight") + + +def test_plot_tissue_values(tmp_path): + features_label = "fa" + group_label = "tissue" + group_names = ["CSF", "GM", "WM"] + min_val_csf = 0.0 + max_val_csf = 0.2 + min_max_csf = (min_val_csf, max_val_csf) + min_val_gm = 0.0 + max_val_gm = 0.6 + min_max_gm = (min_val_gm, max_val_gm) + min_val_wm = 0.3 + max_val_wm = 1.0 + min_max_wm = (min_val_wm, max_val_wm) + min_max = [min_max_csf, min_max_gm, min_max_wm] + n_grp_samples = 250 + data_file = tmp_path / "tissue_fa.tsv" + + _generate_raincloud_random_data( + min_max, n_grp_samples, features_label, group_label, group_names, data_file + ) + + palette = "Set2" + orient = "v" + density = True + output_file = tmp_path / "tissue_fa.png" + mark_nans = False + + plot_tissue_values( + data_file, + group_label, + features_label, + palette=palette, + orient=orient, + density=density, + mark_nans=mark_nans, + output_file=output_file, + ) diff --git a/nireports/tests/test_interfaces.py b/nireports/tests/test_interfaces.py index d361dd6b..d027c0e1 100644 --- a/nireports/tests/test_interfaces.py +++ b/nireports/tests/test_interfaces.py @@ -27,7 +27,12 @@ import pytest -from nireports.interfaces.nuisance import CompCorVariancePlot, ConfoundsCorrelationPlot +from nireports.interfaces.nuisance import ( + CompCorVariancePlot, + ConfoundsCorrelationPlot, + RaincloudPlot, +) +from nireports.tests.utils import _generate_raincloud_random_data def _smoke_test_report(report_interface, artifact_name): @@ -56,3 +61,38 @@ def test_ConfoundsCorrelationPlot(datadir, ignore_initial_volumes): ignore_initial_volumes=ignore_initial_volumes, ) _smoke_test_report(cc_rpt, f"confounds_correlation_{ignore_initial_volumes}.svg") + + +@pytest.mark.parametrize("orient", ["h", "v"]) +@pytest.mark.parametrize("density", (True, False)) +def test_RaincloudPlot(orient, density, tmp_path): + """Raincloud plot report test""" + features_label = "value" + group_label = "group" + group_names = ["group1", "group2"] + min_val_grp1 = 0.3 + max_val_grp1 = 1.0 + min_max_group1 = (min_val_grp1, max_val_grp1) + min_val_grp2 = 0.0 + max_val_grp2 = 0.6 + min_max_group2 = (min_val_grp2, max_val_grp2) + min_max = [min_max_group1, min_max_group2] + n_grp_samples = 250 + data_file = tmp_path / "data.tsv" + + _generate_raincloud_random_data( + min_max, n_grp_samples, features_label, group_label, group_names, data_file + ) + + palette = "Set2" + mark_nans = False + rc_rpt = RaincloudPlot( + data_file=data_file, + group_name=group_label, + feature=features_label, + palette=palette, + orient=orient, + density=density, + mark_nans=mark_nans, + ) + _smoke_test_report(rc_rpt, f"raincloud_orient-{orient}_density-{density}.svg") diff --git a/nireports/tests/test_reportlets.py b/nireports/tests/test_reportlets.py index 671334e0..3ffe682c 100644 --- a/nireports/tests/test_reportlets.py +++ b/nireports/tests/test_reportlets.py @@ -35,9 +35,10 @@ from nireports.reportlets.modality.func import fMRIPlot from nireports.reportlets.mosaic import plot_mosaic -from nireports.reportlets.nuisance import plot_carpet +from nireports.reportlets.nuisance import plot_carpet, plot_raincloud from nireports.reportlets.surface import cifti_surfaces_plot from nireports.reportlets.xca import compcor_variance_plot, plot_melodic_components +from nireports.tests.utils import _generate_raincloud_random_data from nireports.tools.timeseries import cifti_timeseries as _cifti_timeseries from nireports.tools.timeseries import get_tr as _get_tr from nireports.tools.timeseries import nifti_timeseries as _nifti_timeseries @@ -368,3 +369,105 @@ def test_mriqc_plot_mosaic_2(tmp_path, testdata_path, outdir): maxrows=12, annotate=True, ) + + +@pytest.mark.parametrize("orient", ["h", "v"]) +@pytest.mark.parametrize("density", (True, False)) +def test_plot_raincloud(orient, density, tmp_path): + features_label = "value" + group_label = "group" + group_names = ["group1", "group2"] + min_val_grp1 = 0.3 + max_val_grp1 = 1.0 + min_max_group1 = (min_val_grp1, max_val_grp1) + min_val_grp2 = 0.0 + max_val_grp2 = 0.6 + min_max_group2 = (min_val_grp2, max_val_grp2) + min_max = [min_max_group1, min_max_group2] + n_grp_samples = 250 + data_file = tmp_path / "data.tsv" + + _generate_raincloud_random_data( + min_max, n_grp_samples, features_label, group_label, group_names, data_file + ) + + palette = "Set2" + mark_nans = False + output_file = tmp_path / f"raincloud_reg_orient-{orient}_density-{density}.png" + + plot_raincloud( + data_file, + group_label, + features_label, + palette=palette, + orient=orient, + density=density, + mark_nans=mark_nans, + output_file=output_file, + ) + + group_nans = [50, 0] + + _generate_raincloud_random_data( + min_max, + n_grp_samples, + features_label, + group_label, + group_names, + data_file, + group_nans=group_nans, + ) + + mark_nans = True + nans_value = 2.0 + output_file = tmp_path / f"raincloud_nans_orient-{orient}_density-{density}.png" + + plot_raincloud( + data_file, + group_label, + features_label, + palette=palette, + orient=orient, + density=density, + mark_nans=mark_nans, + nans_value=nans_value, + output_file=output_file, + ) + + min_val_grp1 = 0.3 + max_val_grp1 = 1.2 + min_max_group1 = (min_val_grp1, max_val_grp1) + min_val_grp2 = -0.2 + max_val_grp2 = 0.6 + min_max_group2 = (min_val_grp2, max_val_grp2) + min_max = [min_max_group1, min_max_group2] + + _generate_raincloud_random_data( + min_max, + n_grp_samples, + features_label, + group_label, + group_names, + data_file, + group_nans=group_nans, + ) + + upper_limit_value = 1.0 + lower_limit_value = 0.0 + limit_offset = 0.5 + output_file = tmp_path / f"raincloud_nans_limits_orient-{orient}_density-{density}.png" + + plot_raincloud( + data_file, + group_label, + features_label, + palette=palette, + orient=orient, + density=density, + upper_limit_value=upper_limit_value, + lower_limit_value=lower_limit_value, + limit_offset=limit_offset, + mark_nans=mark_nans, + nans_value=nans_value, + output_file=output_file, + ) diff --git a/nireports/tests/utils.py b/nireports/tests/utils.py new file mode 100644 index 00000000..4bcd61d7 --- /dev/null +++ b/nireports/tests/utils.py @@ -0,0 +1,37 @@ +import numpy as np +import pandas as pd + + +def _generate_raincloud_random_data( + min_max, + n_grp_samples, + features_label, + group_label, + group_names, + data_file, + group_nans=None, +): + rng = np.random.default_rng(1234) + + if group_nans is None: + group_nans = [None] * len(min_max) + + # Create some random data in the [min_val, max_val) half-open interval + values = np.array([]) + names = [] + for group_min_max, name, nans in zip(min_max, group_names, group_nans): + min_val = group_min_max[0] + max_val = group_min_max[1] + range_size = max_val - min_val + _values = rng.random(n_grp_samples) * range_size + min_val + + values = np.concatenate((values, _values), axis=0) + names.extend([name] * n_grp_samples) + + if nans: + values = np.concatenate((values, [np.nan] * nans), axis=0) + names.extend([name] * nans) + + df = pd.DataFrame(np.vstack([values, names]).T, columns=[features_label, group_label]) + + df.to_csv(data_file, sep="\t")