diff --git a/src/plopp/backends/common.py b/src/plopp/backends/common.py index f4b1e0f0..a2107e8b 100644 --- a/src/plopp/backends/common.py +++ b/src/plopp/backends/common.py @@ -1,10 +1,13 @@ # SPDX-License-Identifier: BSD-3-Clause # Copyright (c) 2023 Scipp contributors (https://github.com/scipp) +from typing import Literal + import numpy as np import scipp as sc from ..core.utils import merge_masks +from ..graphics.bbox import BoundingBox, axis_bounds def check_ndim(data: sc.DataArray, ndim: int, origin: str) -> None: @@ -68,3 +71,44 @@ def make_line_data(data: sc.DataArray, dim: str) -> dict: for array in (values, mask): array['y'] = np.concatenate([array['y'][0:1], array['y']]) return {'values': values, 'stddevs': error, 'mask': mask, 'hist': hist} + + +def make_line_bbox( + data: sc.DataArray, + dim: str, + errorbars: bool, + xscale: Literal['linear', 'log'], + yscale: Literal['linear', 'log'], +) -> BoundingBox: + """ + Calculate the bounding box of a line artist. + This includes the x and y bounds of the line and optionally the error bars. + + Parameters + ---------- + data: + The data array to extract values from. + dim: + The dimension along which to extract values. + errorbars: + Whether to include error bars in the bounding box. + xscale: + The scale of the x-axis. + yscale: + The scale of the y-axis. + """ + line_x = data.coords[dim] + sel = slice(None) + if data.masks: + sel = ~merge_masks(data.masks) + if set(sel.dims) != set(data.data.dims): + sel = sc.broadcast(sel, sizes=data.data.sizes).copy() + line_y = data.data[sel] + if errorbars: + stddevs = sc.stddevs(data.data[sel]) + line_y = sc.concat([line_y - stddevs, line_y + stddevs], dim) + + return BoundingBox( + **{**axis_bounds(('xmin', 'xmax'), line_x, xscale, pad=True)}, + **{**axis_bounds(('ymin', 'ymax'), line_y, yscale, pad=True)}, + ) diff --git a/src/plopp/backends/matplotlib/line.py b/src/plopp/backends/matplotlib/line.py index 33286d09..298964e0 100644 --- a/src/plopp/backends/matplotlib/line.py +++ b/src/plopp/backends/matplotlib/line.py @@ -10,9 +10,8 @@ from matplotlib.lines import Line2D from numpy.typing import ArrayLike -from ...core.utils import merge_masks -from ...graphics.bbox import BoundingBox, axis_bounds -from ..common import check_ndim, make_line_data +from ...graphics.bbox import BoundingBox +from ..common import check_ndim, make_line_bbox, make_line_data from .canvas import Canvas from .utils import parse_dicts_in_kwargs @@ -199,19 +198,24 @@ def color(self, val): artist.set_color(val) self._canvas.draw() - def bbox(self, xscale: Literal['linear', 'log'], yscale: Literal['linear', 'log']): + def bbox( + self, xscale: Literal['linear', 'log'], yscale: Literal['linear', 'log'] + ) -> BoundingBox: """ The bounding box of the line. - """ - line_x = self._data.coords[self._dim] - sel = ~merge_masks(self._data.masks) if self._data.masks else slice(None) - line_y = self._data.data[sel] - if self._error is not None: - stddevs = sc.stddevs(self._data.data[sel]) - line_y = sc.concat([line_y - stddevs, line_y + stddevs], self._dim) + This includes the x and y bounds of the line and optionally the error bars. - out = BoundingBox( - **{**axis_bounds(('xmin', 'xmax'), line_x, xscale, pad=True)}, - **{**axis_bounds(('ymin', 'ymax'), line_y, yscale, pad=True)}, + Parameters + ---------- + xscale: + The scale of the x-axis. + yscale: + The scale of the y-axis. + """ + return make_line_bbox( + data=self._data, + dim=self._dim, + errorbars=self._error is not None, + xscale=xscale, + yscale=yscale, ) - return out diff --git a/src/plopp/backends/plotly/line.py b/src/plopp/backends/plotly/line.py index d67f43dc..64780774 100644 --- a/src/plopp/backends/plotly/line.py +++ b/src/plopp/backends/plotly/line.py @@ -9,9 +9,8 @@ import scipp as sc from plotly.colors import qualitative as plotly_colors -from ...core.utils import merge_masks -from ...graphics.bbox import BoundingBox, axis_bounds -from ..common import check_ndim, make_line_data +from ...graphics.bbox import BoundingBox +from ..common import check_ndim, make_line_bbox, make_line_data from .canvas import Canvas @@ -224,20 +223,26 @@ def color(self): def color(self, val): self._line.line.color = val - def bbox(self, xscale: Literal['linear', 'log'], yscale: Literal['linear', 'log']): + def bbox( + self, xscale: Literal['linear', 'log'], yscale: Literal['linear', 'log'] + ) -> BoundingBox: """ The bounding box of the line. - """ - line_x = self._data.coords[self._dim] - sel = ~merge_masks(self._data.masks) if self._data.masks else slice(None) - line_y = self._data.data[sel] - if self._error is not None: - stddevs = sc.stddevs(self._data.data[sel]) - line_y = sc.concat([line_y - stddevs, line_y + stddevs], self._dim) + This includes the x and y bounds of the line and optionally the error bars. - out = BoundingBox( - **{**axis_bounds(('xmin', 'xmax'), line_x, xscale, pad=True)}, - **{**axis_bounds(('ymin', 'ymax'), line_y, yscale, pad=True)}, + Parameters + ---------- + xscale: + The scale of the x-axis. + yscale: + The scale of the y-axis. + """ + out = make_line_bbox( + data=self._data, + dim=self._dim, + errorbars=self._error is not None, + xscale=xscale, + yscale=yscale, ) if xscale == 'log': out.xmin = np.log10(out.xmin) diff --git a/tests/plotting/plot_1d_test.py b/tests/plotting/plot_1d_test.py index 6b100cf4..cbe3916e 100644 --- a/tests/plotting/plot_1d_test.py +++ b/tests/plotting/plot_1d_test.py @@ -542,3 +542,12 @@ def test_figure_has_only_unit_on_vertical_axis_for_multiple_curves(): assert a.name not in ylabel assert b.name not in ylabel assert c.name not in ylabel + + +def test_plot_1d_scalar_mask(): + da = sc.DataArray( + sc.ones(sizes={'x': 3}), + coords={'x': sc.arange('x', 3)}, + masks={'m': sc.scalar(False)}, + ) + _ = da.plot()