Skip to content

Commit

Permalink
fix: address memory issues and corruption in fMRIPlot
Browse files Browse the repository at this point in the history
At the same time, this PR ensures that all calls to ``plt.savefig()``
are followed by closing the figure.

Resolves: #129.
Closes: #130.
  • Loading branch information
oesteban committed Aug 24, 2024
1 parent e0a30e6 commit bab640a
Show file tree
Hide file tree
Showing 10 changed files with 36 additions and 17 deletions.
2 changes: 2 additions & 0 deletions nireports/assembler/tests/test_report.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,7 @@ def bids_sessions(tmpdir_factory):
file_path = svg_dir / bids_path
file_path.ensure()
f.savefig(str(file_path))
f.clf()

# create anatomical data
anat_opts = [
Expand All @@ -110,6 +111,7 @@ def bids_sessions(tmpdir_factory):
file_path = svg_dir / bids_path
file_path.ensure()
f.savefig(str(file_path))
f.clf()

return svg_dir.dirname

Expand Down
2 changes: 2 additions & 0 deletions nireports/interfaces/dmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,5 +93,7 @@ def _run_interface(self, runtime):
)

out_figure.savefig(self._results["out_file"], format="svg", dpi=300)
out_figure.clf()
out_figure = None

return runtime
6 changes: 3 additions & 3 deletions nireports/interfaces/fmri.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,7 @@ def _run_interface(self, runtime):
else nifti_timeseries(input_data, seg_file)
)

fig = fMRIPlot(
fMRIPlot(
dataset,
segments=segments,
spikes_files=(
Expand All @@ -109,6 +109,6 @@ def _run_interface(self, runtime):
units={"outliers": "%", "FD": "mm"},
vlines={"FD": [self.inputs.fd_thres]},
nskip=self.inputs.drop_trs,
).plot()
fig.savefig(self._results["out_file"], bbox_inches="tight")
).plot(out_file=self._results["out_file"])

return runtime
24 changes: 17 additions & 7 deletions nireports/reportlets/modality/func.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,22 +89,25 @@ def __init__(
for sp_file in spikes_files:
self.spikes.append((np.loadtxt(sp_file), None, False))

def plot(self, figure=None):
def plot(self, figure=None, out_file=None):
"""Main plotter"""
import seaborn as sns

sns.set_style("whitegrid")
sns.set_context("paper", font_scale=0.8)

if figure is None:
figure = plt.gcf()
figure = plt.figure(figsize=(19.2, 32))

nconfounds = len(self.confounds)
nspikes = len(self.spikes)
nrows = 1 + nconfounds + nspikes

# Create grid
grid = GridSpec(nrows, 1, wspace=0.0, hspace=0.05, height_ratios=[1] * (nrows - 1) + [5])
grid = GridSpec(
nrows,
1,
figure=figure,
wspace=0.0,
hspace=0.05,
height_ratios=[1] * (nrows - 1) + [5],
)

grid_id = 0
for tsz, name, iszs in self.spikes:
Expand All @@ -130,4 +133,11 @@ def plot(self, figure=None):
drop_trs=self.nskip,
cmap="paired" if self.paired_carpet else None,
)

if out_file is not None:
figure.savefig(out_file, bbox_inches="tight")
plt.close(figure)
figure = None
return out_file

return figure
4 changes: 2 additions & 2 deletions nireports/reportlets/nuisance.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,7 @@ def plot_qi2(x_grid, ref_pdf, fit_pdf, ref_data, cutoff_idx, out_file=None):
out_file = op.abspath("qi2_plot.svg")

fig.savefig(out_file, bbox_inches="tight", pad_inches=0, dpi=300)
plt.close(fig=fig)
return out_file


Expand Down Expand Up @@ -340,7 +341,6 @@ def plot_carpet(
height_ratios=[len(v) for v in segments.values()],
)

label = ""
for i, (_, indices) in enumerate(segments.items()):
# Carpet plot
ax = plt.subplot(gs[i])
Expand Down Expand Up @@ -397,7 +397,7 @@ def plot_carpet(
ax.set_title(title)

if nsegments == 1:
ax.set_ylabel(label)
ax.set_ylabel(segments.keys()[0])

if legend:
from matplotlib.patches import Patch
Expand Down
2 changes: 1 addition & 1 deletion nireports/reportlets/surface.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def get_surface_meshes(density, surface_type):

if output_file is not None:
figure.savefig(output_file, bbox_inches="tight", dpi=400)
plt.close(figure)
figure.clf()
return output_file

return figure
Expand Down
1 change: 1 addition & 0 deletions nireports/reportlets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ def svg2str(display_object, dpi=300):
display_object.frame_axes.figure.savefig(
image_buf, dpi=dpi, format="svg", facecolor="k", edgecolor="k"
)
display_object.frame_axes.figure.clf()
image_buf.seek(0)
return image_buf.getvalue()

Expand Down
9 changes: 5 additions & 4 deletions nireports/reportlets/xca.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,9 +367,10 @@ def compcor_variance_plot(
ax[m].spines[side].set_visible(False)

if output_file is not None:
figure = plt.gcf()
figure.savefig(output_file, bbox_inches="tight")
plt.close(figure)
figure = None
if fig is None:
fig = plt.gcf()
fig.savefig(output_file, bbox_inches="tight")
fig.clf()
fig = None
return output_file
return ax
2 changes: 2 additions & 0 deletions nireports/tests/test_dwi.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ def test_plot_dwi(tmp_path, testdata_path, outdir):

if outdir is not None:
plt.savefig(outdir / f"{stem}.svg", bbox_inches="tight")
plt.close(plt.gcf())


@pytest.mark.parametrize(
Expand All @@ -77,6 +78,7 @@ def test_plot_gradients(tmp_path, testdata_path, dwi_btable, outdir):

if outdir is not None:
plt.savefig(outdir / f"{dwi_btable}.svg", bbox_inches="tight")
plt.close(plt.gcf())


def test_plot_tissue_values(tmp_path):
Expand Down
1 change: 1 addition & 0 deletions nireports/tests/test_reportlets.py
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def test_fmriplot(input_files, testdata_path, outdir):
outdir / f"fmriplot_{dtype}{has_seg}.svg",
bbox_inches="tight",
)
fig.clf()


def test_plot_melodic_components(tmp_path, outdir):
Expand Down

0 comments on commit bab640a

Please sign in to comment.