Skip to content

Commit

Permalink
Merge pull request #370 from scipp/1dplots-yaxis-label
Browse files Browse the repository at this point in the history
Add data array name on vertical axis of 1d plots
  • Loading branch information
nvaytet authored Sep 12, 2024
2 parents 5262ed7 + 8b79e87 commit b079269
Show file tree
Hide file tree
Showing 14 changed files with 294 additions and 114 deletions.
21 changes: 17 additions & 4 deletions src/plopp/backends/matplotlib/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

from ...core.utils import maybe_variable_to_number, scalar_to_string
from ...graphics.bbox import BoundingBox
from .utils import fig_to_bytes, is_sphinx_build, make_figure
from .utils import fig_to_bytes, is_sphinx_build, make_figure, make_legend


def _cursor_value_to_variable(x: float, dtype: sc.DType, unit: str) -> sc.Variable:
Expand Down Expand Up @@ -150,6 +150,17 @@ def draw(self):
"""
self.fig.canvas.draw_idle()

def update_legend(self):
"""
Update the legend on the canvas.
"""
if self._legend:
handles, labels = self.ax.get_legend_handles_labels()
if len(handles) > 1:
self.ax.legend(handles, labels, **make_legend(self._legend))
elif (leg := self.ax.get_legend()) is not None:
leg.remove()

def save(self, filename: str, **kwargs):
"""
Save the figure to file.
Expand Down Expand Up @@ -192,9 +203,10 @@ def set_axes(self, dims, units, dtypes):
self._cursor_x_prefix = self.dims['x'] + '='
self._cursor_y_prefix = self.dims['y'] + '='
self.ax.format_coord = self.format_coord
key = 'y' if 'y' in self.units else 'data'
self.bbox = BoundingBox(
ymin=maybe_variable_to_number(self._user_vmin, unit=self.units.get('y')),
ymax=maybe_variable_to_number(self._user_vmax, unit=self.units.get('y')),
ymin=maybe_variable_to_number(self._user_vmin, unit=self.units[key]),
ymax=maybe_variable_to_number(self._user_vmax, unit=self.units[key]),
)

def register_format_coord(self, formatter):
Expand All @@ -216,7 +228,8 @@ def format_coord(self, x: float, y: float) -> str:
The y coordinate of the mouse pointer.
"""
xstr = _cursor_formatter(x, self.dtypes['x'], self.units['x'])
ystr = _cursor_formatter(y, self.dtypes['y'], self.units['y'])
key = 'y' if 'y' in self.dtypes else 'data'
ystr = _cursor_formatter(y, self.dtypes[key], self.units[key])
out = f"({self._cursor_x_prefix}{xstr}, {self._cursor_y_prefix}{ystr})"
if not self._coord_formatters:
return out
Expand Down
8 changes: 1 addition & 7 deletions src/plopp/backends/matplotlib/line.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from ...graphics.bbox import BoundingBox, axis_bounds
from ..common import check_ndim, make_line_data
from .canvas import Canvas
from .utils import make_legend, parse_dicts_in_kwargs
from .utils import parse_dicts_in_kwargs


def _to_float(x):
Expand Down Expand Up @@ -137,11 +137,6 @@ def __init__(
zorder=10,
fmt="none",
)
self.update_legend()

def update_legend(self) -> None:
if self.label and self._canvas._legend:
self._ax.legend(**make_legend(self._canvas._legend))

def update(self, new_values: sc.DataArray):
"""
Expand Down Expand Up @@ -187,7 +182,6 @@ def remove(self):
self._mask.remove()
if self._error is not None:
self._error.remove()
self.update_legend()
self._canvas.draw()

@property
Expand Down
10 changes: 1 addition & 9 deletions src/plopp/backends/matplotlib/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from ...graphics.bbox import BoundingBox, axis_bounds
from ..common import check_ndim
from .canvas import Canvas
from .utils import make_legend, parse_dicts_in_kwargs
from .utils import parse_dicts_in_kwargs


class Scatter:
Expand Down Expand Up @@ -88,14 +88,6 @@ def __init__(
visible=visible_mask,
)

if self._canvas._legend:
leg_args = make_legend(self._canvas._legend)
if np.shape(s) == np.shape(self._data.coords[self._x].values):
handles, labels = self._scatter.legend_elements(prop="sizes")
self._ax.legend(handles, labels, title="Sizes", **leg_args)
if self.label:
self._ax.legend(**leg_args)

def update(self, new_values: sc.DataArray):
"""
Update the x and y positions of the data points from new data.
Expand Down
8 changes: 6 additions & 2 deletions src/plopp/backends/plotly/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,10 @@ def set_axes(self, dims, units, dtypes):
self.dims = dims
self.units = units
self.dtypes = dtypes
key = 'y' if 'y' in self.units else 'data'
self.bbox = BoundingBox(
ymin=maybe_variable_to_number(self._user_vmin, unit=self.units.get('y')),
ymax=maybe_variable_to_number(self._user_vmax, unit=self.units.get('y')),
ymin=maybe_variable_to_number(self._user_vmin, unit=self.units[key]),
ymax=maybe_variable_to_number(self._user_vmax, unit=self.units[key]),
)

@property
Expand Down Expand Up @@ -303,3 +304,6 @@ def logy(self):

def draw(self):
pass

def update_legend(self):
pass
3 changes: 3 additions & 0 deletions src/plopp/backends/pythreejs/canvas.py
Original file line number Diff line number Diff line change
Expand Up @@ -395,3 +395,6 @@ def zrange(self) -> tuple[float, float]:
@zrange.setter
def zrange(self, value: tuple[float, float]):
self.zmin, self.zmax = value

def update_legend(self):
pass
60 changes: 33 additions & 27 deletions src/plopp/graphics/colormapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,8 @@ def __init__(
# raised when making the colorbar before any call to update is made.
self.normalizer = _get_normalizer(self.norm)
self.colorbar = None
self.unit = None

self.name = None
self._unit = None
self.empty = True
self.changed = False
self.artists = {}
self.widget = None
Expand Down Expand Up @@ -221,35 +220,42 @@ def _set_artists_colors(self):
for k in self.artists.keys():
self.artists[k].set_colors(self.rgba(self.artists[k].data))

def update(self, *args, **kwargs):
def update(self):
"""
Update the colorscale bounds taking into account new values,
by either supplying a dictionary of new data or by keyword arguments.
We also update the colorbar widget if it exists.
Update the colors of all artists, if the `self.set_colors_on_update` attribute
is set to `True`.
"""
new = dict(*args, **kwargs)
for data in new.values():
if self.name is None:
self.name = data.name
# If name is None, this is the first time update is called
if self.user_vmin is not None:
self.user_vmin = maybe_variable_to_number(
self.user_vmin, unit=self.unit
)
if self.user_vmax is not None:
self.user_vmax = maybe_variable_to_number(
self.user_vmax, unit=self.unit
)
elif data.name != self.name:
self.name = ''
if self.cax is not None:
text = self.name
if self.unit is not None:
text += f'{" " if self.name else ""}[{self.unit}]'
self.cax.set_ylabel(text)
if self.set_colors_on_update:
self._set_artists_colors()

@property
def unit(self) -> str | None:
"""
Get or set the unit of the colorbar.
"""
return self._unit

@unit.setter
def unit(self, unit: str | None):
self._unit = unit
if self.user_vmin is not None:
self.user_vmin = maybe_variable_to_number(self.user_vmin, unit=self._unit)
if self.user_vmax is not None:
self.user_vmax = maybe_variable_to_number(self.user_vmax, unit=self._unit)

@property
def ylabel(self) -> str | None:
"""
Get or set the label of the colorbar axis.
"""
if self.cax is not None:
return self.cax.get_ylabel()

@ylabel.setter
def ylabel(self, lab: str):
if self.cax is not None:
self.cax.set_ylabel(lab)

def toggle_norm(self):
"""
Toggle the norm flag, between `linear` and `log`.
Expand Down
68 changes: 43 additions & 25 deletions src/plopp/graphics/graphicalview.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,15 +65,15 @@ def __init__(
self._repr_format = format
self.bbox = BoundingBox()
self.draw_on_update = True
self._data_name = None
self._data_axis = None

self.canvas = canvas_maker(
cbar=cbar,
aspect=aspect,
grid=grid,
figsize=figsize,
title=title,
vmin=vmin,
vmax=vmax,
legend=legend,
camera=camera,
ax=ax,
Expand Down Expand Up @@ -129,6 +129,7 @@ def update(self, *args, **kwargs) -> None:
new data or by keyword arguments.
"""
new = dict(*args, **kwargs)
need_legend_update = False
for key, new_values in new.items():
coords = {}
for i, direction in enumerate(self._dims):
Expand All @@ -143,42 +144,53 @@ def update(self, *args, **kwargs) -> None:
) from e

if self.canvas.empty:
self._data_name = new_values.name
axes_units = {k: coord.unit for k, coord in coords.items()}
axes_dtypes = {k: coord.dtype for k, coord in coords.items()}
if 'y' in self._dims:
self.canvas.ylabel = name_with_unit(
var=coords['y'], name=self._dims['y']
)
if self._dims['y'] in self._scale:
self.canvas.yscale = self._scale[self._dims['y']]

if set(self._dims) == {'x'}:
axes_units['data'] = new_values.unit
axes_dtypes['data'] = new_values.dtype
if self.colormapper is not None:
self.colormapper.unit = new_values.unit
axes_units['data'] = new_values.unit
axes_dtypes['data'] = new_values.dtype
self._data_axis = self.colormapper
else:
self.canvas.ylabel = name_with_unit(var=new_values.data, name="")
axes_units['y'] = new_values.unit
axes_dtypes['y'] = new_values.dtype
self._data_axis = self.canvas

self.canvas.set_axes(
dims=self._dims, units=axes_units, dtypes=axes_dtypes
)
self.canvas.xlabel = name_with_unit(
var=coords['x'], name=self._dims['x']
)
if self.colormapper is not None:
self.colormapper.unit = new_values.unit
if self._dims['x'] in self._scale:
self.canvas.xscale = self._scale[self._dims['x']]
else:
if self.colormapper is not None:
new_values.data = make_compatible(
new_values.data, unit=self.colormapper.unit

for xyz, dim in self._dims.items():
setattr(
self.canvas,
f'{xyz}label',
name_with_unit(var=coords[xyz], name=dim),
)
if dim in self._scale:
setattr(self.canvas, f'{xyz}scale', self._scale[dim])

if self._data_axis is not None:
self._data_axis.ylabel = name_with_unit(
var=new_values.data, name=self._data_name
)

else:
for xy, dim in self._dims.items():
new_values.coords[dim] = make_compatible(
coords[xy], unit=self.canvas.units[xy]
)
if 'y' not in self._dims:
if 'data' in self.canvas.units:
new_values.data = make_compatible(
new_values.data, unit=self.canvas.units['y']
new_values.data, unit=self.canvas.units['data']
)
if self._data_name and (new_values.name != self._data_name):
self._data_name = None
self._data_axis.ylabel = name_with_unit(
var=sc.scalar(0.0, unit=self.canvas.units['data']), name=''
)

if key not in self.artists:
self.artists[key] = self._artist_maker(
Expand All @@ -191,10 +203,15 @@ def update(self, *args, **kwargs) -> None:
if self.colormapper is not None:
self.colormapper[key] = self.artists[key]

need_legend_update = getattr(self.artists[key], "label", False)

self.artists[key].update(new_values=new_values)

if self.colormapper is not None:
self.colormapper.update(**new)
self.colormapper.update()

if need_legend_update:
self.canvas.update_legend()

if self.draw_on_update:
self.canvas.draw()
Expand Down Expand Up @@ -233,4 +250,5 @@ def remove(self, key: str) -> None:
"""
self.artists[key].remove()
del self.artists[key]
self.canvas.update_legend()
self.autoscale()
5 changes: 5 additions & 0 deletions src/plopp/plotting/scatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,7 @@ def scatter(
vmax: sc.Variable | float = None,
cbar: bool = False,
cmap: str = 'viridis',
legend: bool | tuple[float, float] = True,
**kwargs,
) -> FigureLike:
"""
Expand Down Expand Up @@ -81,6 +82,9 @@ def scatter(
using the data values in the supplied data array.
cmap:
The colormap to be used for the colorscale.
legend:
Show legend if ``True``. If ``legend`` is a tuple, it should contain the
``(x, y)`` coordinates of the legend's anchor point in axes coordinates.
**kwargs:
All other kwargs are forwarded the underlying plotting library.
"""
Expand All @@ -102,5 +106,6 @@ def scatter(
vmax=vmax,
cmap=cmap,
cbar=cbar,
legend=legend,
**kwargs,
)
30 changes: 30 additions & 0 deletions tests/backends/matplotlib/mpl_plot_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,36 @@ def test_legend_location():
assert leg3[0] == leg1[0]


def test_no_legend_for_single_artist():
da = data_array(ndim=1)
da.name = "Velocity"
p = pp.plot(da)
leg = p.ax.get_legend()
assert leg is None


def test_legend_is_removed_when_only_one_artist_is_left():
a = data_array(ndim=1)
b = 2.3 * a
f = pp.plot({'a': a, 'b': b})
ka, _ = f.view.artists.keys()
f.view.remove(ka)
assert f.ax.get_legend() is None


def test_legend_entry_is_removed_when_artist_is_removed():
a = data_array(ndim=1)
b = 2.3 * a
c = 0.8 * a
f = pp.plot({'a': a, 'b': b, 'c': c})
ka, _, _ = f.view.artists.keys()
f.view.remove(ka)
texts = f.ax.get_legend().get_texts()
assert len(texts) == 2
assert texts[0].get_text() == 'b'
assert texts[1].get_text() == 'c'


def test_with_string_coord_1d():
strings = ['a', 'b', 'c', 'd', 'e']
da = sc.DataArray(
Expand Down
Loading

0 comments on commit b079269

Please sign in to comment.