Skip to content

Commit

Permalink
updates for handking figure and axes
Browse files Browse the repository at this point in the history
  • Loading branch information
akhanf committed Dec 1, 2024
1 parent 33a057f commit 1150894
Showing 1 changed file with 53 additions and 20 deletions.
73 changes: 53 additions & 20 deletions hippunfold_plot/plotting.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,27 @@
import matplotlib.pyplot as plt
from nilearn.plotting import plot_surf
from hippunfold_plot.utils import get_surf_limits, get_data_limits, get_resource_path, check_surf_map_is_label_gii, get_legend_elements_from_label_gii

def plot_hipp_surf(surf_map, density='0p5mm', hemi='left', space=None, figsize=(12, 8), dpi=300, vmin=None, vmax=None, colorbar=False, colorbar_shrink=0.25, cmap=None, view='dorsal', avg_method='median', bg_on_data=True, alpha=0.1, darkness=2, **kwargs):
from typing import Union, Tuple, Optional, List

def plot_hipp_surf(surf_map: Union[str, list],
density: str = '0p5mm',
hemi: str = 'left',
space: Optional[str] = None,
figsize: Tuple[int, int] = (12, 8),
dpi: int = 300,
vmin: Optional[float] = None,
vmax: Optional[float] = None,
colorbar: bool = False,
colorbar_shrink: float = 0.25,
cmap: Optional[Union[str, plt.cm.ScalarMappable]] = None,
view: str = 'dorsal',
avg_method: str = 'median',
bg_on_data: bool = True,
alpha: float = 0.1,
darkness: float = 2,
axes: Optional[Union[plt.Axes, List[plt.Axes]]] = None,
figure: Optional[plt.Figure] = None,
**kwargs) -> plt.Figure:
"""Plot hippocampal surface map.
This function plots a surface map of the hippocampus, which can be a label-hippdentate shape.gii, func.gii, or a Vx1 array
Expand Down Expand Up @@ -43,6 +62,11 @@ def plot_hipp_surf(surf_map, density='0p5mm', hemi='left', space=None, figsize=(
The alpha transparency level. Default is 0.1.
darkness : float, optional
The darkness level of the background. Default is 2.
axes : matplotlib.axes.Axes or list of matplotlib.axes.Axes, optional
Axes to plot on. If None, new axes will be created. If a single axis is provided, it will be used for a single plot.
If multiple plots are to be made, a list of axes should be provided.
figure : matplotlib.figure.Figure, optional
The figure to plot on. If None, a new figure will be created.
**kwargs : dict
Additional arguments to pass to nilearn's plot_surf().
Expand Down Expand Up @@ -87,37 +111,46 @@ def plot_hipp_surf(surf_map, density='0p5mm', hemi='left', space=None, figsize=(

#add any user arguments
plot_kwargs.update(kwargs)

# Create a figure
fig = plt.figure(figsize=figsize,dpi=dpi) # Adjust figure size for tall axes

# Define positions for 4 tall side-by-side axes
positions = [
[0.05, 0.1, 0.2, 0.8], # Left, bottom, width, height
[0.18, 0.1, 0.2, 0.8],
[0.30, 0.1, 0.2, 0.8],
[0.43, 0.1, 0.2, 0.8],
[0.55, 0.1, 0.2, 0.8],

]

# Define the plotting order for each hemisphere
hemi_space_map = {
'left': ['unfold', 'canonical'],
'right': ['canonical', 'unfold']
}


# Determine the number of plots to be made
hemis_to_plot = [hemi] if hemi else hemi_space_map.keys()
num_plots = sum(len([space] if space else hemi_space_map[h]) for h in hemis_to_plot)

# Validate axes input
if axes is not None:
if isinstance(axes, plt.Axes):
if num_plots > 1:
raise ValueError("Multiple plots requested, but only one axis provided.")
axes = [axes]
elif isinstance(axes, list):
if len(axes) != num_plots:
raise ValueError(f"Expected {num_plots} axes, but got {len(axes)}.")
else:
raise ValueError("Invalid type for 'axes'. Expected matplotlib.axes.Axes or list of matplotlib.axes.Axes.")


# Create a figure if not provided
if fig is None:
fig = plt.figure(figsize=figsize, dpi=dpi)

# Create axes if not provided
if axes is None:
axes = [fig.add_subplot(1, num_plots, i + 1, projection='3d') for i in range(num_plots)]


pos=0

# Build the composite plot
hemis_to_plot = [hemi] if hemi else hemi_space_map.keys()
for h in hemis_to_plot:
spaces_to_plot = [space] if space else hemi_space_map[h]
for s in spaces_to_plot:

ax = fig.add_axes(positions[pos], projection='3d') # Add 3D axes

ax = axes[pos]
plot_surf(surf_mesh=surf_gii.format(hemi=h,space=s,density=density),
axes=ax,
figure=fig,
Expand Down

0 comments on commit 1150894

Please sign in to comment.