Skip to content

Commit

Permalink
Merge pull request #734 from VincentRouvreau/fix/persistence_graphica…
Browse files Browse the repository at this point in the history
…l_plot_tools_default_behaviour

legend=True is the default behaviour for persistence graphical plot tools
  • Loading branch information
VincentRouvreau authored Feb 2, 2023
2 parents 84c749d + b7a5257 commit a54745f
Show file tree
Hide file tree
Showing 3 changed files with 96 additions and 89 deletions.
13 changes: 4 additions & 9 deletions src/python/doc/persistence_graphical_tools_user.rst
Original file line number Diff line number Diff line change
Expand Up @@ -44,8 +44,7 @@ This function can display the persistence result as a diagram:
# rips_on_tore3D_1307.pers obtained from write_persistence_diagram method
persistence_file=gudhi.__root_source_dir__ + \
'/data/persistence_diagram/rips_on_tore3D_1307.pers'
ax = gudhi.plot_persistence_diagram(persistence_file=persistence_file,
legend=True)
ax = gudhi.plot_persistence_diagram(persistence_file=persistence_file)
# We can modify the title, aspect, etc.
ax.set_title("Persistence diagram of a torus")
ax.set_aspect("equal") # forces to be square shaped
Expand Down Expand Up @@ -80,15 +79,11 @@ If you want more information on a specific dimension, for instance:
persistence_file=gudhi.__root_source_dir__ + \
'/data/persistence_diagram/rips_on_tore3D_1307.pers'
birth_death = gudhi.read_persistence_intervals_in_dimension(
persistence_file=persistence_file,
only_this_dim=1)
pers_diag = [(1, elt) for elt in birth_death]
persistence_file=persistence_file, only_this_dim=1)
# Use subplots to display diagram and density side by side
fig, axes = plt.subplots(nrows=1, ncols=2, figsize=(12, 5))
gudhi.plot_persistence_diagram(persistence=pers_diag,
axes=axes[0])
gudhi.plot_persistence_density(persistence=pers_diag,
dimension=1, legend=True, axes=axes[1])
gudhi.plot_persistence_diagram(persistence=birth_death, axes=axes[0])
gudhi.plot_persistence_density(persistence=birth_death, axes=axes[1])
plt.show()

LaTeX support
Expand Down
162 changes: 85 additions & 77 deletions src/python/gudhi/persistence_graphical_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
#
# Modification(s):
# - 2020/02 Theo Lacombe: Added more options for improved rendering and more flexibility.
# - 2022/11 Vincent Rouvreau: "Automatic" legend display detected by _array_handler that returns if the persistence
# was a nx2 array.
# - YYYY/MM Author: Description of the modification

from os import path
Expand Down Expand Up @@ -56,20 +58,22 @@ def __min_birth_max_death(persistence, band=0.0):

def _array_handler(a):
"""
:param a: if array, assumes it is a (n x 2) np.array and return a
:param a: if array, assumes it is a (n x 2) np.array and returns a
persistence-compatible list (padding with 0), so that the
plot can be performed seamlessly.
:returns: * List[dimension, [birth, death]] Persistence, compatible with plot functions, list.
* boolean Modification status (True if output is different from input)
"""
if isinstance(a[0][1], (np.floating, float)):
return [[0, x] for x in a]
return [[0, x] for x in a], True
else:
return a
return a, False


def _limit_to_max_intervals(persistence, max_intervals, key):
"""This function returns truncated persistence if length is bigger than max_intervals.
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
:type persistence: an array of (dimension, (birth, death)) or an array of (birth, death).
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
to 0 to see all. Default value is 1000.
Expand Down Expand Up @@ -107,37 +111,38 @@ def plot_persistence_barcode(
alpha=0.6,
max_intervals=20000,
inf_delta=0.1,
legend=False,
legend=None,
colormap=None,
axes=None,
fontsize=16,
):
"""This function plots the persistence bar code from persistence values list
, a np.array of shape (N x 2) (representing a diagram
in a single homology dimension),
, a np.array of shape (N x 2) (representing a diagram
in a single homology dimension),
or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file.
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
:type persistence: an array of (dimension, (birth, death)) or an array of (birth, death)
:param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name
(reset persistence if both are set).
:type persistence_file: string
:param alpha: barcode transparency value (0.0 transparent through 1.0
opaque - default is 0.6).
:type alpha: float.
:type alpha: float
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
to 0 to see all. Default value is 20000.
:type max_intervals: int.
:type max_intervals: int
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
between 0.05 and 0.5 - default is 0.1.
:type inf_delta: float.
:param legend: Display the dimension color legend (default is False).
:type legend: boolean.
:type inf_delta: float
:param legend: Display the dimension color legend. Default is None, meaning the legend is displayed if dimension
is specified in the persistence argument, and not displayed if dimension is not specified.
:type legend: boolean or None
:param colormap: A matplotlib-like qualitative colormaps. Default is None
which means :code:`matplotlib.cm.Set1.colors`.
:type colormap: tuple of colors (3-tuple of float between 0. and 1.).
:type colormap: tuple of colors (3-tuple of float between 0. and 1.)
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
Expand All @@ -157,6 +162,8 @@ def plot_persistence_barcode(
plt.rc("text", usetex=False)
plt.rc("font", family="DejaVu Sans")

# By default, let's say the persistence is not an array of shape (N x 2) - Can be from a persistence file
nx2_array = False
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
Expand All @@ -169,7 +176,7 @@ def plot_persistence_barcode(
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)

try:
persistence = _array_handler(persistence)
persistence, nx2_array = _array_handler(persistence)
persistence = _limit_to_max_intervals(
persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
)
Expand All @@ -190,16 +197,21 @@ def plot_persistence_barcode(
if colormap == None:
colormap = plt.cm.Set1.colors

x=[birth for (dim,(birth,death)) in persistence]
y=[(death - birth) if death != float("inf") else (infinity - birth) for (dim,(birth,death)) in persistence]
c=[colormap[dim] for (dim,(birth,death)) in persistence]
x = [birth for (dim, (birth, death)) in persistence]
y = [(death - birth) if death != float("inf") else (infinity - birth) for (dim, (birth, death)) in persistence]
c = [colormap[dim] for (dim, (birth, death)) in persistence]

axes.barh(range(len(x)), y, left=x, alpha=alpha, color=c, linewidth=0)

if legend is None and not nx2_array:
# By default, if persistence is an array of (dimension, (birth, death)), display the legend
legend = True

if legend:
dimensions = set(item[0] for item in persistence)
axes.legend(
handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions], loc="lower right",
handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions],
loc="best",
)

axes.set_title("Persistence barcode", fontsize=fontsize)
Expand All @@ -222,7 +234,7 @@ def plot_persistence_diagram(
band=0.0,
max_intervals=1000000,
inf_delta=0.1,
legend=False,
legend=None,
colormap=None,
axes=None,
fontsize=16,
Expand All @@ -233,28 +245,29 @@ def plot_persistence_diagram(
homology dimension, or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file`.
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, array of (birth, death)) or an array of (birth, death).
:type persistence: an array of (dimension, (birth, death)) or an array of (birth, death)
:param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_ file style name
(reset persistence if both are set).
:type persistence_file: string
:param alpha: plot transparency value (0.0 transparent through 1.0
opaque - default is 0.6).
:type alpha: float.
:type alpha: float
:param band: band (not displayed if :math:`\leq` 0. - default is 0.)
:type band: float.
:type band: float
:param max_intervals: maximal number of intervals to display.
Selected intervals are those with the longest life time. Set it
to 0 to see all. Default value is 1000000.
:type max_intervals: int.
:type max_intervals: int
:param inf_delta: Infinity is placed at :code:`((max_death - min_birth) x
inf_delta)` above :code:`max_death` value. A reasonable value is
between 0.05 and 0.5 - default is 0.1.
:type inf_delta: float.
:param legend: Display the dimension color legend (default is False).
:type legend: boolean.
:type inf_delta: float
:param legend: Display the dimension color legend. Default is None, meaning the legend is displayed if dimension
is specified in the persistence argument, and not displayed if dimension is not specified.
:type legend: boolean or None
:param colormap: A matplotlib-like qualitative colormaps. Default is None
which means :code:`matplotlib.cm.Set1.colors`.
:type colormap: tuple of colors (3-tuple of float between 0. and 1.).
:type colormap: tuple of colors (3-tuple of float between 0. and 1.)
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type axes: `matplotlib.axes.Axes`
Expand All @@ -276,6 +289,8 @@ def plot_persistence_diagram(
plt.rc("text", usetex=False)
plt.rc("font", family="DejaVu Sans")

# By default, let's say the persistence is not an array of shape (N x 2) - Can be from a persistence file
nx2_array = False
if persistence_file != "":
if path.isfile(persistence_file):
# Reset persistence
Expand All @@ -288,7 +303,7 @@ def plot_persistence_diagram(
raise FileNotFoundError(errno.ENOENT, os.strerror(errno.ENOENT), persistence_file)

try:
persistence = _array_handler(persistence)
persistence, nx2_array = _array_handler(persistence)
persistence = _limit_to_max_intervals(
persistence, max_intervals, key=lambda life_time: life_time[1][1] - life_time[1][0]
)
Expand Down Expand Up @@ -324,12 +339,12 @@ def plot_persistence_diagram(
# line display of equation : birth = death
axes.plot([axis_start, axis_end], [axis_start, axis_end], linewidth=1.0, color="k")

x=[birth for (dim,(birth,death)) in persistence]
y=[death if death != float("inf") else infinity for (dim,(birth,death)) in persistence]
c=[colormap[dim] for (dim,(birth,death)) in persistence]
x = [birth for (dim, (birth, death)) in persistence]
y = [death if death != float("inf") else infinity for (dim, (birth, death)) in persistence]
c = [colormap[dim] for (dim, (birth, death)) in persistence]

axes.scatter(x,y,alpha=alpha,color=c)
if float("inf") in (death for (dim,(birth,death)) in persistence):
axes.scatter(x, y, alpha=alpha, color=c)
if float("inf") in (death for (dim, (birth, death)) in persistence):
# infinity line and text
axes.plot([axis_start, axis_end], [infinity, infinity], linewidth=1.0, color="k", alpha=alpha)
# Infinity label
Expand All @@ -341,9 +356,16 @@ def plot_persistence_diagram(
axes.set_yticks(yt)
axes.set_yticklabels(ytl)

if legend is None and not nx2_array:
# By default, if persistence is an array of (dimension, (birth, death)), display the legend
legend = True

if legend:
dimensions = list(set(item[0] for item in persistence))
axes.legend(handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions])
axes.legend(
handles=[mpatches.Patch(color=colormap[dim], label=str(dim)) for dim in dimensions],
loc="lower right",
)

axes.set_xlabel("Birth", fontsize=fontsize)
axes.set_ylabel("Death", fontsize=fontsize)
Expand All @@ -364,61 +386,46 @@ def plot_persistence_density(
max_intervals=1000,
dimension=None,
cmap=None,
legend=False,
legend=True,
axes=None,
fontsize=16,
greyblock=False,
):
"""This function plots the persistence density from persistence
values list, np.array of shape (N x 2) representing a diagram
in a single homology dimension,
or from a `persistence diagram <fileformats.html#persistence-diagram>`_ file.
Be aware that this function does not distinguish the dimension, it is
up to you to select the required one. This function also does not handle
degenerate data set (scipy correlation matrix inversion can fail).
"""This function plots the persistence density from persistence values list, np.array of shape (N x 2) representing
a diagram in a single homology dimension, or from a `persistence diagram <fileformats.html#persistence-diagram>`_
file. Be aware that this function does not distinguish the dimension, it is up to you to select the required one.
This function also does not handle degenerate data set (scipy correlation matrix inversion can fail).
:Requires: `SciPy <installation.html#scipy>`_
:param persistence: Persistence intervals values list.
Can be grouped by dimension or not.
:type persistence: an array of (dimension, array of (birth, death))
or an array of (birth, death).
:param persistence: Persistence intervals values list. Can be grouped by dimension or not.
:type persistence: an array of (dimension, (birth, death)) or an array of (birth, death)
:param persistence_file: A `persistence diagram <fileformats.html#persistence-diagram>`_
file style name (reset persistence if both are set).
:type persistence_file: string
:param nbins: Evaluate a gaussian kde on a regular grid of nbins x
nbins over data extents (default is 300)
:type nbins: int.
:param bw_method: The method used to calculate the estimator
bandwidth. This can be 'scott', 'silverman', a scalar constant
or a callable. If a scalar, this will be used directly as
kde.factor. If a callable, it should take a gaussian_kde
instance as only parameter and return a scalar. If None
(default), 'scott' is used. See
:param nbins: Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents (default is 300)
:type nbins: int
:param bw_method: The method used to calculate the estimator bandwidth. This can be 'scott', 'silverman', a scalar
constant or a callable. If a scalar, this will be used directly as kde.factor. If a callable, it should take a
gaussian_kde instance as only parameter and return a scalar. If None (default), 'scott' is used. See
`scipy.stats.gaussian_kde documentation
<http://scipy.github.io/devdocs/generated/scipy.stats.gaussian_kde.html>`_
<https://docs.scipy.org/doc/scipy/reference/generated/scipy.stats.gaussian_kde.html>`_
for more details.
:type bw_method: str, scalar or callable, optional.
:param max_intervals: maximal number of points used in the density
estimation.
Selected intervals are those with the longest life time. Set it
to 0 to see all. Default value is 1000.
:type max_intervals: int.
:param dimension: the dimension to be selected in the intervals
(default is None to mix all dimensions).
:type dimension: int.
:param cmap: A matplotlib colormap (default is
matplotlib.pyplot.cm.hot_r).
:type cmap: cf. matplotlib colormap.
:param legend: Display the color bar values (default is False).
:type legend: boolean.
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on
a new set of axes.
:type bw_method: str, scalar or callable, optional
:param max_intervals: maximal number of points used in the density estimation. Selected intervals are those with
the longest life time. Set it to 0 to see all. Default value is 1000.
:type max_intervals: int
:param dimension: the dimension to be selected in the intervals (default is None to mix all dimensions).
:type dimension: int
:param cmap: A matplotlib colormap (default is matplotlib.pyplot.cm.hot_r).
:type cmap: cf. matplotlib colormap
:param legend: Display the color bar values (default is True).
:type legend: boolean
:param axes: A matplotlib-like subplot axes. If None, the plot is drawn on a new set of axes.
:type axes: `matplotlib.axes.Axes`
:param fontsize: Fontsize to use in axis.
:type fontsize: int
:param greyblock: if we want to plot a grey patch on the lower half plane
for nicer rendering. Default False.
:param greyblock: if we want to plot a grey patch on the lower half plane for nicer rendering. Default False.
:type greyblock: boolean
:returns: (`matplotlib.axes.Axes`): The axes on which the plot was drawn.
"""
Expand Down Expand Up @@ -454,7 +461,7 @@ def plot_persistence_density(

try:
# if not read from file but given by an argument
persistence = _array_handler(persistence)
persistence, _ = _array_handler(persistence)
persistence_dim = np.array(
[
(dim_interval[1][0], dim_interval[1][1])
Expand All @@ -480,7 +487,8 @@ def plot_persistence_density(
# Evaluate a gaussian kde on a regular grid of nbins x nbins over data extents
k = kde.gaussian_kde([birth, death], bw_method=bw_method)
xi, yi = np.mgrid[
birth_min : birth_max : nbins * 1j, death_min : death_max : nbins * 1j,
birth_min : birth_max : nbins * 1j,
death_min : death_max : nbins * 1j,
]
zi = k(np.vstack([xi.flatten(), yi.flatten()]))
# Make the plot
Expand Down
10 changes: 7 additions & 3 deletions src/python/test/test_persistence_graphical_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,19 +17,23 @@

def test_array_handler():
diags = np.array([[1, 2], [3, 4], [5, 6]], float)
arr_diags = gd.persistence_graphical_tools._array_handler(diags)
arr_diags, nx2_array = gd.persistence_graphical_tools._array_handler(diags)
assert nx2_array
for idx in range(len(diags)):
assert arr_diags[idx][0] == 0
np.testing.assert_array_equal(arr_diags[idx][1], diags[idx])

diags = [(1.0, 2.0), (3.0, 4.0), (5.0, 6.0)]
arr_diags = gd.persistence_graphical_tools._array_handler(diags)
arr_diags, nx2_array = gd.persistence_graphical_tools._array_handler(diags)
assert nx2_array
for idx in range(len(diags)):
assert arr_diags[idx][0] == 0
assert arr_diags[idx][1] == diags[idx]

diags = [(0, (1.0, 2.0)), (0, (3.0, 4.0)), (0, (5.0, 6.0))]
assert gd.persistence_graphical_tools._array_handler(diags) == diags
arr_diags, nx2_array = gd.persistence_graphical_tools._array_handler(diags)
assert not nx2_array
assert arr_diags == diags


def test_min_birth_max_death():
Expand Down

0 comments on commit a54745f

Please sign in to comment.