From bab640ac6403b8b5167e790c437f2e9d55446d4d Mon Sep 17 00:00:00 2001 From: Oscar Esteban Date: Sun, 25 Aug 2024 01:08:25 +0200 Subject: [PATCH] fix: address memory issues and corruption in fMRIPlot At the same time, this PR ensures that all calls to ``plt.savefig()`` are followed by closing the figure. Resolves: #129. Closes: #130. --- nireports/assembler/tests/test_report.py | 2 ++ nireports/interfaces/dmri.py | 2 ++ nireports/interfaces/fmri.py | 6 +++--- nireports/reportlets/modality/func.py | 24 +++++++++++++++++------- nireports/reportlets/nuisance.py | 4 ++-- nireports/reportlets/surface.py | 2 +- nireports/reportlets/utils.py | 1 + nireports/reportlets/xca.py | 9 +++++---- nireports/tests/test_dwi.py | 2 ++ nireports/tests/test_reportlets.py | 1 + 10 files changed, 36 insertions(+), 17 deletions(-) diff --git a/nireports/assembler/tests/test_report.py b/nireports/assembler/tests/test_report.py index 24d1b156..97ff7a4e 100644 --- a/nireports/assembler/tests/test_report.py +++ b/nireports/assembler/tests/test_report.py @@ -87,6 +87,7 @@ def bids_sessions(tmpdir_factory): file_path = svg_dir / bids_path file_path.ensure() f.savefig(str(file_path)) + f.clf() # create anatomical data anat_opts = [ @@ -110,6 +111,7 @@ def bids_sessions(tmpdir_factory): file_path = svg_dir / bids_path file_path.ensure() f.savefig(str(file_path)) + f.clf() return svg_dir.dirname diff --git a/nireports/interfaces/dmri.py b/nireports/interfaces/dmri.py index abd6008f..891215fe 100644 --- a/nireports/interfaces/dmri.py +++ b/nireports/interfaces/dmri.py @@ -93,5 +93,7 @@ def _run_interface(self, runtime): ) out_figure.savefig(self._results["out_file"], format="svg", dpi=300) + out_figure.clf() + out_figure = None return runtime diff --git a/nireports/interfaces/fmri.py b/nireports/interfaces/fmri.py index 6bf8bf22..fd635116 100644 --- a/nireports/interfaces/fmri.py +++ b/nireports/interfaces/fmri.py @@ -98,7 +98,7 @@ def _run_interface(self, runtime): else nifti_timeseries(input_data, seg_file) ) - fig = fMRIPlot( + fMRIPlot( dataset, segments=segments, spikes_files=( @@ -109,6 +109,6 @@ def _run_interface(self, runtime): units={"outliers": "%", "FD": "mm"}, vlines={"FD": [self.inputs.fd_thres]}, nskip=self.inputs.drop_trs, - ).plot() - fig.savefig(self._results["out_file"], bbox_inches="tight") + ).plot(out_file=self._results["out_file"]) + return runtime diff --git a/nireports/reportlets/modality/func.py b/nireports/reportlets/modality/func.py index 73aabc83..7a8d324b 100644 --- a/nireports/reportlets/modality/func.py +++ b/nireports/reportlets/modality/func.py @@ -89,22 +89,25 @@ def __init__( for sp_file in spikes_files: self.spikes.append((np.loadtxt(sp_file), None, False)) - def plot(self, figure=None): + def plot(self, figure=None, out_file=None): """Main plotter""" - import seaborn as sns - - sns.set_style("whitegrid") - sns.set_context("paper", font_scale=0.8) if figure is None: - figure = plt.gcf() + figure = plt.figure(figsize=(19.2, 32)) nconfounds = len(self.confounds) nspikes = len(self.spikes) nrows = 1 + nconfounds + nspikes # Create grid - grid = GridSpec(nrows, 1, wspace=0.0, hspace=0.05, height_ratios=[1] * (nrows - 1) + [5]) + grid = GridSpec( + nrows, + 1, + figure=figure, + wspace=0.0, + hspace=0.05, + height_ratios=[1] * (nrows - 1) + [5], + ) grid_id = 0 for tsz, name, iszs in self.spikes: @@ -130,4 +133,11 @@ def plot(self, figure=None): drop_trs=self.nskip, cmap="paired" if self.paired_carpet else None, ) + + if out_file is not None: + figure.savefig(out_file, bbox_inches="tight") + plt.close(figure) + figure = None + return out_file + return figure diff --git a/nireports/reportlets/nuisance.py b/nireports/reportlets/nuisance.py index e9b58714..c7bb11a6 100644 --- a/nireports/reportlets/nuisance.py +++ b/nireports/reportlets/nuisance.py @@ -213,6 +213,7 @@ def plot_qi2(x_grid, ref_pdf, fit_pdf, ref_data, cutoff_idx, out_file=None): out_file = op.abspath("qi2_plot.svg") fig.savefig(out_file, bbox_inches="tight", pad_inches=0, dpi=300) + plt.close(fig=fig) return out_file @@ -340,7 +341,6 @@ def plot_carpet( height_ratios=[len(v) for v in segments.values()], ) - label = "" for i, (_, indices) in enumerate(segments.items()): # Carpet plot ax = plt.subplot(gs[i]) @@ -397,7 +397,7 @@ def plot_carpet( ax.set_title(title) if nsegments == 1: - ax.set_ylabel(label) + ax.set_ylabel(segments.keys()[0]) if legend: from matplotlib.patches import Patch diff --git a/nireports/reportlets/surface.py b/nireports/reportlets/surface.py index 37d36764..b571a18e 100644 --- a/nireports/reportlets/surface.py +++ b/nireports/reportlets/surface.py @@ -158,7 +158,7 @@ def get_surface_meshes(density, surface_type): if output_file is not None: figure.savefig(output_file, bbox_inches="tight", dpi=400) - plt.close(figure) + figure.clf() return output_file return figure diff --git a/nireports/reportlets/utils.py b/nireports/reportlets/utils.py index 7b15adee..f63e5074 100644 --- a/nireports/reportlets/utils.py +++ b/nireports/reportlets/utils.py @@ -153,6 +153,7 @@ def svg2str(display_object, dpi=300): display_object.frame_axes.figure.savefig( image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k" ) + display_object.frame_axes.figure.clf() image_buf.seek(0) return image_buf.getvalue() diff --git a/nireports/reportlets/xca.py b/nireports/reportlets/xca.py index 2476465a..97575fd7 100644 --- a/nireports/reportlets/xca.py +++ b/nireports/reportlets/xca.py @@ -367,9 +367,10 @@ def compcor_variance_plot( ax[m].spines[side].set_visible(False) if output_file is not None: - figure = plt.gcf() - figure.savefig(output_file, bbox_inches="tight") - plt.close(figure) - figure = None + if fig is None: + fig = plt.gcf() + fig.savefig(output_file, bbox_inches="tight") + fig.clf() + fig = None return output_file return ax diff --git a/nireports/tests/test_dwi.py b/nireports/tests/test_dwi.py index 382eda1f..9bb27582 100644 --- a/nireports/tests/test_dwi.py +++ b/nireports/tests/test_dwi.py @@ -58,6 +58,7 @@ def test_plot_dwi(tmp_path, testdata_path, outdir): if outdir is not None: plt.savefig(outdir / f"{stem}.svg", bbox_inches="tight") + plt.close(plt.gcf()) @pytest.mark.parametrize( @@ -77,6 +78,7 @@ def test_plot_gradients(tmp_path, testdata_path, dwi_btable, outdir): if outdir is not None: plt.savefig(outdir / f"{dwi_btable}.svg", bbox_inches="tight") + plt.close(plt.gcf()) def test_plot_tissue_values(tmp_path): diff --git a/nireports/tests/test_reportlets.py b/nireports/tests/test_reportlets.py index 3f6cd1ee..2f2f9e07 100644 --- a/nireports/tests/test_reportlets.py +++ b/nireports/tests/test_reportlets.py @@ -173,6 +173,7 @@ def test_fmriplot(input_files, testdata_path, outdir): outdir / f"fmriplot_{dtype}{has_seg}.svg", bbox_inches="tight", ) + fig.clf() def test_plot_melodic_components(tmp_path, outdir):