Skip to content

Commit

Permalink
ENH: Add raincloud plot capabilities
Browse files Browse the repository at this point in the history
Add raincloud plot capabilities and demonstrate on the FA values of
computed on a DWI volume across different tissue types.
  • Loading branch information
jhlegarreta committed May 15, 2024
1 parent 60fd854 commit 230a240
Show file tree
Hide file tree
Showing 7 changed files with 337 additions and 9 deletions.
7 changes: 6 additions & 1 deletion nireports/interfaces/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,11 +24,16 @@

from nireports.interfaces.fmri import FMRISummary
from nireports.interfaces.mosaic import PlotContours, PlotMosaic, PlotSpikes
from nireports.interfaces.nuisance import CompCorVariancePlot, ConfoundsCorrelationPlot
from nireports.interfaces.nuisance import (
CompCorVariancePlot,
ConfoundsCorrelationPlot,
RaincloudPlot,
)

__all__ = (
"CompCorVariancePlot",
"ConfoundsCorrelationPlot",
"RaincloudPlot",
"FMRISummary",
"PlotContours",
"PlotMosaic",
Expand Down
65 changes: 64 additions & 1 deletion nireports/interfaces/nuisance.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
)
from nipype.utils.filemanip import fname_presuffix

from nireports.reportlets.nuisance import confounds_correlation_plot
from nireports.reportlets.nuisance import confounds_correlation_plot, plot_raincloud
from nireports.reportlets.xca import compcor_variance_plot


Expand Down Expand Up @@ -142,3 +142,66 @@ def _run_interface(self, runtime):
ignore_initial_volumes=self.inputs.ignore_initial_volumes,
)
return runtime


class _RaincloudPlotInputSpec(BaseInterfaceInputSpec):
data_file = File(exists=True, mandatory=True, desc="File containing the data")
out_file = traits.Either(None, File, value=None, usedefault=True, desc="Path to save plot")
group_label = traits.Str(
"group_label",
mandatory=True,
desc="Group name of interest",
)
feature = traits.Str(
"feature",
mandatory=True,
desc="Feature of interest",
)
palette = traits.Str(
"Set2",
usedefault=True,
desc="Color palette name",
)
orient = traits.Str(
"v",
usedefault=True,
desc="Orientation",
)
density = traits.Bool(
True,
usedefault=True,
desc="``True`` to plot the density",
)


class _RaincloudPlotOutputSpec(TraitedSpec):
out_file = File(exists=True, desc="Path to saved plot")


class RaincloudPlot(SimpleInterface):
"""Plot a raincloud of values."""

input_spec = _RaincloudPlotInputSpec
output_spec = _RaincloudPlotOutputSpec

def _run_interface(self, runtime):
if self.inputs.out_file is None:
self._results["out_file"] = fname_presuffix(
self.inputs.data_file,
suffix="_raincloud.svg",
use_ext=False,
newpath=runtime.cwd,
)
else:
self._results["out_file"] = self.inputs.out_file
plot_raincloud(
data_file=self.inputs.data_file,
group_label=self.inputs.group_label,
feature=self.inputs.feature,
palette=self.inputs.palette,
orient=self.inputs.orient,
density=self.inputs.density,
output_file=self._results["out_file"],
**kwargs,
)
return runtime
12 changes: 12 additions & 0 deletions nireports/reportlets/modality/dwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,13 @@
import matplotlib as mpl
import nibabel as nb
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
from mpl_toolkits.mplot3d import art3d
from nilearn.plotting import plot_anat

from nireports.reportlets.nuisance import plot_carpet


def plot_dwi(dataobj, affine, gradient=None, **kwargs):
"""
Expand Down Expand Up @@ -405,3 +408,12 @@ def plot_gradients(
plt.suptitle(title)

return ax


def plot_tissue_values(pvms, fa):
# Plot FA values in voxels for GM, WM, CSF as boxplots
# ToDo
# Define colors for CST, GM, WM or read them from somehwere

df = pd.DataFrame(pvms, columns=["group, score"])
return plot_raincloud(df)
158 changes: 154 additions & 4 deletions nireports/reportlets/nuisance.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import matplotlib as mpl
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.backends.backend_pdf import FigureCanvasPdf as FigureCanvas
from matplotlib.colorbar import ColorbarBase
Expand Down Expand Up @@ -626,8 +627,6 @@ def confoundplot(
cutoff=None,
ylims=None,
):
import seaborn as sns

# Define TR and number of frames
notr = False
if tr is None:
Expand Down Expand Up @@ -856,8 +855,6 @@ def confounds_correlation_plot(
output_file: :obj:`str`
The file where the figure is saved.
"""
import pandas as pd
import seaborn as sns

confounds_data = pd.read_table(confounds_file)

Expand Down Expand Up @@ -932,3 +929,156 @@ def confounds_correlation_plot(
figure = None
return output_file
return [ax0, ax1], gs


def plot_raincloud(
data_file,
group_name,
feature,
palette="Set2",
orient="v",
density=True,
figure=None,
output_file=None,
**kwargs,
):
"""
Generate a raincloud plot with the data points corresponding to the
``feature`` value contained in the data file.
Parameters
----------
data_file : :obj:`str`
File containing the data of interest.
figure : :obj:`matplotlib.pyplot.figure` or None
Existing figure on which to plot.
group_name : :obj:`str`
The group name of interest to be plot.
feature : :obj:`str`
The feature of interest to be plot.
palette : :obj:`str`, optional
Color palette name provided to :func:`sns.stripplot`.
orient : :obj:`str`, optional
Plot orientation (``v`` or ``h``).
density : :obj:`bool`, optional
``True`` to plot the density of the data points.
output_file : :obj:`str` or :obj:`None`
Path where the output figure should be saved. If this is not defined,
then the plotting axes will be returned instead of the saved figure
path.
kwargs : :obj:`dict`
Extra args given to :func:`sns.violinplot`, :func:`sns.stripplot` and
:func:`sns.boxplot`.
Returns
-------
axes and gridspec
Plotting axes and gridspec. Returned only if ``output_file`` is ``None``.
output_file : :obj:`str`
The file where the figure is saved.
"""

# ToDo
# Think how data will come or why not accept the df directly ?
# df = pd.read_table(data_file)
df = pd.read_csv(data_file, sep=r"[\t\s]+", engine="python")

if figure is None:
plt.figure(figsize=(7, 5))

gs = GridSpec(1, 1)
ax = plt.subplot(gs[0, 0])

# ToDo
# Make all the below go to the kwargs. Note that they take different kwargs
# so we will need 3 dictionaries: one fro violinplot, one for stripplot, and
# the other for the boxplot, and one for the general style ?
sns.set(style="whitegrid", font_scale=2)

x = feature
y = group_name
# Swap x/y if the requested orientation is vertical
if orient == "v":
x = group_name
y = feature

# Plot the violin plots
if density:
ax = sns.violinplot(
x=x,
y=y,
data=df,
hue=group_name,
dodge=False,
palette=palette,
scale="width",
inner=None,
orient=orient,
)

# Cut half of the violins
for violin in ax.collections:
bbox = violin.get_paths()[0].get_extents()
x0, y0, width, height = bbox.bounds
width_denom = 2
height_denom = 1
if orient == "h":
width_denom = 1
height_denom = 2
violin.set_clip_path(
plt.Rectangle(
(x0, y0), width / width_denom, height / height_denom, transform=ax.transData
)
)

# Add boxplots
width = 0.15
sns.boxplot(
x=x,
y=y,
data=df,
color="black",
width=width,
zorder=10,
showcaps=True,
boxprops={"facecolor": "none", "zorder": 10},
showfliers=True,
whiskerprops={"linewidth": 2, "zorder": 10},
saturation=1,
orient=orient,
ax=ax,
)

old_len_collections = len(ax.collections)

# Plot the data points as dots
sns.stripplot(
x=x,
y=y,
hue=group_name,
data=df,
palette=palette,
edgecolor="white",
size=3,
jitter=0.1,
zorder=0,
orient=orient,
ax=ax,
)

# Offset the dots that would be otherwise shadowed by the violins
if density:
offset = np.array([width, 0])
if orient == "h":
offset = np.array([0, width])
for dots in ax.collections[old_len_collections:]:
dots.set_offsets(dots.get_offsets() + offset)

if output_file is not None:
figure = plt.gcf()
figure.savefig(output_file, bbox_inches="tight")
plt.close(figure)
figure = None
return output_file

return ax, gs
22 changes: 21 additions & 1 deletion nireports/tests/test_dwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import pytest
from matplotlib import pyplot as plt

from nireports.reportlets.modality.dwi import plot_dwi, plot_gradients
from nireports.reportlets.modality.dwi import plot_dwi, plot_gradients, plot_tissue_values


def test_plot_dwi(tmp_path, testdata_path, outdir):
Expand Down Expand Up @@ -69,3 +69,23 @@ def test_plot_gradients(tmp_path, testdata_path, dwi_btable, outdir):

if outdir is not None:
plt.savefig(outdir / f"{dwi_btable}.svg", bbox_inches="tight")


def test_plot_tissue_values():
# Plot FA values in voxels for GM, WM, CSF as raincloud
# df = pd.DataFrame(pvms, columns=["group, score"])
data_file = ""
group_name = "tissue"
feature = "FA"
figure = None
output_file = None
kwargs = {}

plot_tissue_values(
data_file,
group_name,
feature,
figure=figure,
output_file=output_file,
**kwargs,
)
29 changes: 28 additions & 1 deletion nireports/tests/test_interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,11 @@

import pytest

from nireports.interfaces.nuisance import CompCorVariancePlot, ConfoundsCorrelationPlot
from nireports.interfaces.nuisance import (
CompCorVariancePlot,
ConfoundsCorrelationPlot,
RaincloudPlot,
)


def _smoke_test_report(report_interface, artifact_name):
Expand Down Expand Up @@ -56,3 +60,26 @@ def test_ConfoundsCorrelationPlot(datadir, ignore_initial_volumes):
ignore_initial_volumes=ignore_initial_volumes,
)
_smoke_test_report(cc_rpt, f"confounds_correlation_{ignore_initial_volumes}.svg")


# @pytest.mark.parametrize("orient", ["h", "v"])
# @pytest.mark.parametrize("density", (True, False))
def test_RaincloudPlot(
datadir,
):
"""Raincloud plot report test"""
data_file = os.path.join(datadir, "raincloud_test.tsv")
group_label = "group"
feature = "score"
palette = "Set2"
orient = "v"
density = True
rc_rpt = RaincloudPlot(
data_file=data_file,
group_label=group_label,
feature=feature,
palette=palette,
orient=orient,
density=density,
)
_smoke_test_report(rc_rpt, f"raincloud_{orient}_{density}.svg")
Loading

0 comments on commit 230a240

Please sign in to comment.