Skip to content

Commit

Permalink
fixes to get axes and figure args working
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Dec 1, 2024
1 parent 1150894 commit a04d3e2
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 17 deletions.
File renamed without changes.
38 changes: 21 additions & 17 deletions hippunfold_plot/plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
from nilearn.plotting import plot_surf
from hippunfold_plot.utils import get_surf_limits, get_data_limits, get_resource_path, check_surf_map_is_label_gii, get_legend_elements_from_label_gii
from typing import Union, Tuple, Optional, List
import numpy as np

def plot_hipp_surf(surf_map: Union[str, list],
density: str = '0p5mm',
Expand Down Expand Up @@ -72,19 +73,14 @@ def plot_hipp_surf(surf_map: Union[str, list],
Returns
-------
fig : matplotlib.figure.Figure
figure : matplotlib.figure.Figure
The figure object.
mappable : matplotlib.cm.ScalarMappable, optional
The mappable object, if return_mappable is True.
Notes
-----
By default, this function will plot one hemisphere (left by default) in both canonical and unfolded space.
Both surfaces can be plotted with hemi=None, but the same surf_map will be plotted on both.
Use return_mappable=True if you want to make a colorbar afterwards, e.g.:
fig, mappable = plot_hipp_surf(..., return_mappable=True)
plt.colorbar(mappable, shrink=0.5) # shrink makes it smaller which is recommended
"""
# Validate inputs
valid_densities = ['unfoldiso', '0p5mm', '1mm', '2mm']
Expand Down Expand Up @@ -128,22 +124,30 @@ def plot_hipp_surf(surf_map: Union[str, list],
if num_plots > 1:
raise ValueError("Multiple plots requested, but only one axis provided.")
axes = [axes]
elif isinstance(axes, list):
elif isinstance(axes, np.ndarray):
if len(axes) != num_plots:
raise ValueError(f"Expected {num_plots} axes, but got {len(axes)}.")
else:
raise ValueError("Invalid type for 'axes'. Expected matplotlib.axes.Axes or list of matplotlib.axes.Axes.")
raise ValueError("Invalid type for 'axes'. Expected matplotlib.axes.Axes or array of matplotlib.axes.Axes.")


# Create a figure if not provided
if fig is None:
fig = plt.figure(figsize=figsize, dpi=dpi)

if figure is None:
figure = plt.figure(figsize=figsize, dpi=dpi)

# Define positions for 4 tall side-by-side axes
positions = [
[0.05, 0.1, 0.2, 0.8], # Left, bottom, width, height
[0.18, 0.1, 0.2, 0.8],
[0.30, 0.1, 0.2, 0.8],
[0.43, 0.1, 0.2, 0.8],
[0.55, 0.1, 0.2, 0.8],
]

# Create axes if not provided
if axes is None:
axes = [fig.add_subplot(1, num_plots, i + 1, projection='3d') for i in range(num_plots)]


axes = [figure.add_axes(positions[i], projection='3d') for i in range(num_plots)]

pos=0

# Build the composite plot
Expand All @@ -153,7 +157,7 @@ def plot_hipp_surf(surf_map: Union[str, list],
ax = axes[pos]
plot_surf(surf_mesh=surf_gii.format(hemi=h,space=s,density=density),
axes=ax,
figure=fig,
figure=figure,
**plot_kwargs)
(xlim_kwargs,ylim_kwargs) = get_surf_limits(surf_mesh=surf_gii.format(hemi=h,space=s,density=density))
ax.set_xlim(**xlim_kwargs)
Expand All @@ -168,7 +172,7 @@ def plot_hipp_surf(surf_map: Union[str, list],
norm = mpl.colors.Normalize(vmin=vmin if vmin else datamin, vmax=vmax if vmax else datamax) # Match your data range
sm = mpl.cm.ScalarMappable(cmap=cmap, norm=norm)
sm.set_array([]) # Dummy array for ScalarMappable
plt.colorbar(sm,ax=fig.axes,shrink=colorbar_shrink)
plt.colorbar(sm,ax=figure.axes,shrink=colorbar_shrink)

return fig
return figure

0 comments on commit a04d3e2

Please sign in to comment.