diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 376c4b9..e649de2 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -46,13 +46,12 @@ jobs: with: ref: ${{ github.event.pull_request.head.sha || github.ref }} - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@v14 + uses: mamba-org/setup-micromamba@v1 with: environment-file: environment.yml environment-name: DEVELOP - channels: conda-forge - cache-env: true - extra-specs: | + cache-environment: true + create-args: >- python=3.10 - name: Install package run: | @@ -140,14 +139,13 @@ jobs: with: ref: ${{ github.event.pull_request.head.sha || github.ref }} - name: Install Conda environment with Micromamba - uses: mamba-org/provision-with-micromamba@v12 + uses: mamba-org/setup-micromamba@v1 with: environment-file: environment${{ matrix.extra }}.yml environment-name: DEVELOP${{ matrix.extra }} - channels: conda-forge - cache-env: true - cache-env-key: ubuntu-latest-${{ matrix.python-version }}${{ matrix.extra }}. - extra-specs: | + cache-environment: true + cache-environment-key: ubuntu-latest-${{ matrix.python-version }}${{ matrix.extra }}. + create-args: >- python=${{matrix.python-version }} - name: Install package run: | diff --git a/docs/examples/gallery/interactive/meteogram.ipynb b/docs/examples/gallery/interactive/meteogram.ipynb new file mode 100644 index 0000000..70eb271 --- /dev/null +++ b/docs/examples/gallery/interactive/meteogram.ipynb @@ -0,0 +1,103 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8795a812-8184-4e2b-b040-2893355634f3", + "metadata": {}, + "source": [ + "This example demonstrates how to create an interactive meteogram visualisation from a coverageJSON file." + ] + }, + { + "cell_type": "markdown", + "id": "b40bf27f-1a4f-44cc-b5fa-8d855059cca5", + "metadata": {}, + "source": [ + "
\n", + " Note: This notebook is rendered in many different ways depending on where you are viewing it (e.g. GitHub, Jupyter, readthedocs etc.). To maximise compatibility with many possible rendering methods, all interactive plots are rendered with chart.show(renderer=\"png\"), which removes all interactivity and only shows a PNG image render.

\n", + " If you are running this notebook in an interactive session yourself and would like to interact with the plots, remove the renderer=\"png\" argument from each call to chart.show().\n", + "
" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "fef7c299-f5d5-476c-b1d0-4d8d1499f6a5", + "metadata": {}, + "outputs": [], + "source": [ + "from earthkit.plots.interactive import Chart\n", + "import earthkit.data" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "92606a9c-846e-4c75-b780-b142669116a2", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "timeseries.json: 0%| | 0.00/690k [00:00=0.22.0", "pint", "matplotlib", + "plotly", "pyyaml", "numpy", "adjustText" diff --git a/src/earthkit/plots/_plugins.py b/src/earthkit/plots/_plugins.py index 2c27f5b..8050b29 100644 --- a/src/earthkit/plots/_plugins.py +++ b/src/earthkit/plots/_plugins.py @@ -1,10 +1,40 @@ +# Copyright 2024, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + from importlib.metadata import entry_points from pathlib import Path def register_plugins(): plugins = dict() - for plugin in entry_points(group="earthkit.plots.plugins"): + + # Compatibility adjustment for Python 3.9 and earlier + all_entry_points = entry_points() + + if hasattr(all_entry_points, "select"): + # For Python 3.10 and above + plugin_entry_points = all_entry_points.select(group="earthkit.plots.plugins") + else: + # For Python 3.9 and below, access the group directly from the dictionary + # and ensure it defaults to an empty list if not found + plugin_entry_points = all_entry_points.get("earthkit.plots.plugins", []) + + # Additional handling for consistency in 3.9 by converting entry points if needed + if isinstance(plugin_entry_points, dict): + plugin_entry_points = plugin_entry_points.get("earthkit.plots.plugins", []) + + for plugin in plugin_entry_points: path = Path(plugin.load().__file__).parents[0] plugins[plugin.name] = { "identities": path / "identities", @@ -14,6 +44,7 @@ def register_plugins(): for key, value in plugins[plugin.name].items(): if not value.exists(): plugins[plugin.name][key] = None + return plugins diff --git a/src/earthkit/plots/components/subplots.py b/src/earthkit/plots/components/subplots.py index 376f2ae..a57ea7d 100644 --- a/src/earthkit/plots/components/subplots.py +++ b/src/earthkit/plots/components/subplots.py @@ -374,19 +374,6 @@ def _extract_plottables( x_values = source.x_values y_values = source.y_values - if method_name in ( - "contour", - "contourf", - "pcolormesh", - ) and not grids.is_structured(x_values, y_values): - x_values, y_values, z_values = grids.interpolate_unstructured( - x_values, - y_values, - z_values, - method=kwargs.pop("interpolation_method", "linear"), - ) - extract_domain = False - if every is not None: x_values = x_values[::every] y_values = y_values[::every] @@ -405,9 +392,23 @@ def _extract_plottables( warnings.warn( "The 'interpolation_method' argument is only valid for unstructured data." ) - mappable = getattr(style, method_name)( - self.ax, x_values, y_values, z_values, **kwargs - ) + try: + mappable = getattr(style, method_name)( + self.ax, x_values, y_values, z_values, **kwargs + ) + except TypeError as err: + if not grids.is_structured(x_values, y_values): + x_values, y_values, z_values = grids.interpolate_unstructured( + x_values, + y_values, + z_values, + method=kwargs.pop("interpolation_method", "linear"), + ) + mappable = getattr(style, method_name)( + self.ax, x_values, y_values, z_values, **kwargs + ) + else: + raise err self.layers.append(Layer(source, mappable, self, style)) return mappable @@ -947,6 +948,10 @@ def show(self): """Display the plot.""" return self.figure.show() + def save(self, *args, **kwargs): + """Save the plot to a file.""" + return self.figure.save(*args, **kwargs) + def thin_array(array, every=2): """ diff --git a/src/earthkit/plots/geo/grids.py b/src/earthkit/plots/geo/grids.py index 1db379d..85f0ef7 100644 --- a/src/earthkit/plots/geo/grids.py +++ b/src/earthkit/plots/geo/grids.py @@ -21,59 +21,109 @@ _NO_SCIPY = True -def is_structured(lat, lon, tol=1e-5): +def is_structured(x, y, tol=1e-5): + """ + Determines whether the x and y points form a structured grid. + + This function checks if the x and y coordinate arrays represent a structured + grid, i.e., a grid with consistent spacing between points. The function supports + 1D arrays (representing coordinates of a grid) and 2D arrays (representing the + actual grid coordinates) of x and y. + + Parameters + ---------- + x : array_like + A 1D or 2D array of x-coordinates. For example, this can be longitude or + the x-coordinate in a Cartesian grid. + y : array_like + A 1D or 2D array of y-coordinates. For example, this can be latitude or + the y-coordinate in a Cartesian grid. + tol : float, optional + Tolerance for floating-point comparison to account for numerical precision + errors when checking spacing consistency. The default is 1e-5. + + Returns + ------- + bool + True if the data represents a structured grid, i.e., the spacing between + consecutive points in both x and y is consistent. False otherwise. """ - Determines whether the latitude and longitude points form a structured grid. - Parameters: - - lat: A 1D or 2D array of latitude points. - - lon: A 1D or 2D array of longitude points. - - tol: Tolerance for floating-point comparison (default 1e-5). + x = np.asarray(x) + y = np.asarray(y) - Returns: - - True if the data is structured (grid), False if it's unstructured. - """ + # If both x and y are 1D arrays, ensure they can form a grid + if x.ndim == 1 and y.ndim == 1: + # Check if the number of points match (can form a meshgrid) + if len(x) * len(y) != x.size * y.size: + return False + + # Check consistent spacing in x and y + x_diff = np.diff(x) + y_diff = np.diff(y) - lat = np.asarray(lat) - lon = np.asarray(lon) + x_spacing_consistent = np.all(np.abs(x_diff - x_diff[0]) < tol) + y_spacing_consistent = np.all(np.abs(y_diff - y_diff[0]) < tol) - # Check if there are consistent spacing in latitudes and longitudes - unique_lat = np.unique(lat) - unique_lon = np.unique(lon) + return x_spacing_consistent and y_spacing_consistent - # Structured grid condition: the number of unique lat/lon values should multiply to the number of total points - if len(unique_lat) * len(unique_lon) == len(lat) * len(lon): - # Now check if the spacing is consistent - lat_diff = np.diff(unique_lat) - lon_diff = np.diff(unique_lon) + # If x and y are 2D arrays, verify they are structured as a grid + elif x.ndim == 2 and y.ndim == 2: + # Check if rows of x and y have consistent spacing along the grid lines + # x should vary only along one axis, y along the other axis - # Check if lat/lon differences are consistent - lat_spacing_consistent = np.all(np.abs(lat_diff - lat_diff[0]) < tol) - lon_spacing_consistent = np.all(np.abs(lon_diff - lon_diff[0]) < tol) + x_rows_consistent = np.all( + np.abs(np.diff(x, axis=1) - np.diff(x, axis=1)[:, 0:1]) < tol + ) + y_columns_consistent = np.all( + np.abs(np.diff(y, axis=0) - np.diff(y, axis=0)[0:1, :]) < tol + ) - return lat_spacing_consistent and lon_spacing_consistent + return x_rows_consistent and y_columns_consistent - # If the product of unique lat/lon values doesn't match total points, it's unstructured - return False + else: + # Invalid input, dimensions of x and y must match (either both 1D or both 2D) + return False def interpolate_unstructured(x, y, z, resolution=1000, method="linear"): """ - Interpolates unstructured data to a structured grid, handling NaNs in z-values - and preventing interpolation across large gaps. - - Parameters: - - x: 1D array of x-coordinates. - - y: 1D array of y-coordinates. - - z: 1D array of z values. - - resolution: The number of points along each axis for the structured grid. - - method: Interpolation method ('linear', 'nearest', 'cubic'). - - gap_threshold: The distance threshold beyond which interpolation is not performed (set to NaN). - - Returns: - - grid_x: 2D grid of x-coordinates. - - grid_y: 2D grid of y-coordinates. - - grid_z: 2D grid of interpolated z-values, with NaNs in large gap regions. + Interpolate unstructured data to a structured grid. + + This function takes unstructured (scattered) data points and interpolates them + to a structured grid, handling NaN values in `z` and providing options for + different interpolation methods. It creates a regular grid based on the given + resolution and interpolates the z-values from the unstructured points onto this grid. + + Parameters + ---------- + x : array_like + 1D array of x-coordinates. + y : array_like + 1D array of y-coordinates. + z : array_like + 1D array of z-values at each (x, y) point. + resolution : int, optional + The number of points along each axis for the structured grid. + Default is 1000. + method : {'linear', 'nearest', 'cubic'}, optional + The interpolation method to use. Default is 'linear'. + The methods supported are: + + - 'linear': Linear interpolation between points. + - 'nearest': Nearest-neighbor interpolation. + - 'cubic': Cubic interpolation, which may produce smoother results. + + Returns + ------- + grid_x : ndarray + 2D array representing the x-coordinates of the structured grid. + grid_y : ndarray + 2D array representing the y-coordinates of the structured grid. + grid_z : ndarray + 2D array of interpolated z-values at the grid points. NaNs may be + present in regions where interpolation was not possible (e.g., due to + large gaps in the data). """ if _NO_SCIPY: raise ImportError( diff --git a/src/earthkit/plots/interactive/__init__.py b/src/earthkit/plots/interactive/__init__.py new file mode 100644 index 0000000..f85a9f5 --- /dev/null +++ b/src/earthkit/plots/interactive/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2024, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from earthkit.plots.interactive.charts import Chart + +__all__ = [ + "Chart", +] diff --git a/src/earthkit/plots/interactive/bar.py b/src/earthkit/plots/interactive/bar.py new file mode 100644 index 0000000..efacdd3 --- /dev/null +++ b/src/earthkit/plots/interactive/bar.py @@ -0,0 +1,23 @@ +# Copyright 2024, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import plotly.graph_objects as go + +from earthkit.plots.interactive import inputs + + +@inputs.sanitise() +def bar(*args, **kwargs): + trace = go.Bar(*args, **kwargs) + return trace diff --git a/src/earthkit/plots/interactive/box.py b/src/earthkit/plots/interactive/box.py new file mode 100644 index 0000000..19b68d2 --- /dev/null +++ b/src/earthkit/plots/interactive/box.py @@ -0,0 +1,118 @@ +# Copyright 2024, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import plotly.graph_objects as go + +from earthkit.plots.interactive import inputs + +THICKEST = 0.6 +THINNEST = 0.3 + +DEFAULT_QUANTILES = [0.05, 0.25, 0.5, 0.75, 0.95] + +DEFAULT_KWARGS = { + "line_color": "#6E78FA", + "fillcolor": "#B1B6FC", +} + + +@inputs.sanitise(multiplot=False) +def box(*args, quantiles=DEFAULT_QUANTILES, time_axis=0, **kwargs): + """ + Generate a set of box plot traces based on the provided data and quantiles. + + Parameters + ---------- + data : array-like or earthkit.data.FieldList + The data to be plotted. + + *args : tuple + Positional arguments passed to the plotly `go.Box` constructors. + + quantiles : list of float, optional + A list of quantiles to calculate for the data. The default is + [0.05, 0.25, 0.5, 0.75, 0.95]. Note that any number of quantiles + can be provided, but the default is based on the standard five-point + box plot. + + time_axis : int, optional + The axis along which to calculate the quantiles. The default is 0. + + **kwargs : dict + Additional keyword arguments passed to the `go.Box` constructor. + + Returns + ------- + list of plotly.graph_objects.Box + + Notes + ----- + - The width of the box plots is scaled based on the x-axis spacing. + - Extra boxes are added for quantiles beyond the standard five-point box plot. + - Hover information is included for quantile scatter points, showing the + quantile value and percentage. + """ + kwargs = {**DEFAULT_KWARGS, **kwargs} + + extra_boxes = (len(quantiles) - 5) // 2 + + quantile_values = np.quantile(kwargs.pop("y"), quantiles, axis=time_axis) + + x = kwargs["x"] + width = float(x[1] - x[0]) * 1e-06 + + traces = [] + traces.append( + go.Box( + *args, + lowerfence=quantile_values[0], + upperfence=quantile_values[-1], + q1=quantile_values[1], + q3=quantile_values[-2], + median=quantile_values[len(quantiles) // 2], + width=width * (THICKEST if not extra_boxes else THINNEST), + hoverinfo="skip", + **kwargs, + ) + ) + + for j in range(extra_boxes): + traces.append( + go.Box( + *args, + lowerfence=quantile_values[0], + upperfence=quantile_values[-1], + showwhiskers=False, + q1=quantile_values[1 + (j + 1)], + q3=quantile_values[-2 - (j + 1)], + median=quantile_values[len(quantiles) // 2], + width=width * THICKEST, + hoverinfo="skip", + **kwargs, + ) + ) + + for y, p in zip(quantile_values, quantiles): + traces.append( + go.Scatter( + y=y, + x=kwargs["x"], + mode="markers", + marker={"size": 0.00001, "color": kwargs.get("line_color", "#333333")}, + hovertemplate=f"%{{y:.2f}}P{p*100:g}%", + ) + ) + + return traces diff --git a/src/earthkit/plots/interactive/charts.py b/src/earthkit/plots/interactive/charts.py new file mode 100644 index 0000000..2973436 --- /dev/null +++ b/src/earthkit/plots/interactive/charts.py @@ -0,0 +1,314 @@ +# Copyright 2024, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from plotly.subplots import make_subplots + +from earthkit.plots.interactive import bar, box, inputs, line + +DEFAULT_LAYOUT = { + "colorway": [ + "#636EFA", + "#EF553B", + "#00CC96", + "#AB63FA", + "#FFA15A", + "#19D3F3", + "#FF6692", + "#B6E880", + "#FF97FF", + "#FECB52", + ], + "hovermode": "x", + "plot_bgcolor": "white", + "xaxis": { + "gridwidth": 1, + "showgrid": False, + "showline": False, + "zeroline": False, + }, + "yaxis": { + "linecolor": "black", + "gridcolor": "#EEEEEE", + "showgrid": True, + "showline": True, + "zeroline": False, + }, + "height": 750, + "showlegend": False, +} + + +class Chart: + """ + A class for creating and managing multi-subplot interactive charts using Plotly. + + Parameters + ---------- + rows : int, optional + Number of rows in the chart grid. Default is 1. + columns : int, optional + Number of columns in the chart grid. Default is 1. + **kwargs : dict + Additional arguments passed to `plotly.subplots.make_subplots`. + """ + + def __init__(self, rows=None, columns=None, **kwargs): + self._rows = rows + self._columns = columns + + self._fig = None + self._subplots = [] + self._subplots_kwargs = kwargs + self._subplot_titles = None + self._subplot_y_titles = None + self._subplot_x_titles = None + self._layout_override = dict() + + def set_subplot_titles(method): + def wrapper(self, *args, **kwargs): + if self._subplot_titles is None: + if args: + try: + ds = inputs.to_xarray(args[0]) + except Exception: + pass + else: + self._subplot_titles = list(ds.data_vars) + titles = [ + ds[data_var].attrs.get("units", "") + for data_var in ds.data_vars + ] + if kwargs.get("y") is not None: + self._subplot_x_titles = titles + else: + self._subplot_y_titles = titles + return method(self, *args, **kwargs) + + return wrapper + + @property + def fig(self): + """ + The Plotly figure object representing the chart. + """ + if self._fig is None: + self._fig = make_subplots( + rows=self.rows, + cols=self.columns, + subplot_titles=self._subplot_titles, + **self._subplots_kwargs, + ) + return self._fig + + @property + def rows(self): + """The number of rows in the chart grid.""" + if self._rows is None: + self._rows = 1 + return self._rows + + @property + def columns(self): + """The number of columns in the chart grid.""" + if self._columns is None: + self._columns = 1 + return self._columns + + def add_trace(self, *args, **kwargs): + """ + Adds a trace to the chart at the appropriate location. + + Parameters + ---------- + *args : tuple + Positional arguments passed to `plotly.graph_objects.Figure.add_trace`. + **kwargs : dict + Keyword arguments passed to `plotly.graph_objects.Figure.add_trace`. + """ + self.fig.add_trace(*args, **kwargs) + + @set_subplot_titles + def line(self, *args, **kwargs): + """ + Adds a line plot to the chart. + + Parameters + ---------- + data : array-like or earthkit.data.FieldList + The data to be plotted. + *args : tuple + Positional arguments passed to the line plot generation function. + **kwargs : dict + Additional options for customizing the line plot. + + Notes + ----- + Line plots are added as individual traces to each subplot. + Titles are inferred from data attributes if not provided. + """ + traces = line.line(*args, **kwargs) + for i, trace in enumerate(traces): + if isinstance(trace, list): + if self._fig is None: + self._rows = self._rows or len(traces) + self._columns = self._columns or 1 + for sub_trace in trace: + self.add_trace(sub_trace, row=i + 1, col=1) + else: + self.add_trace(trace) + + @set_subplot_titles + def box(self, *args, **kwargs): + """ + Generate a set of box plot traces based on the provided data and quantiles. + + Parameters + ---------- + data : array-like or earthkit.data.FieldList + The data to be plotted. + + *args : tuple + Positional arguments passed to the plotly `go.Box` constructors. + + quantiles : list of float, optional + A list of quantiles to calculate for the data. The default is + [0.05, 0.25, 0.5, 0.75, 0.95]. Note that any number of quantiles + can be provided, but the default is based on the standard five-point + box plot. + + time_axis : int, optional + The axis along which to calculate the quantiles. The default is 0. + + **kwargs : dict + Additional keyword arguments passed to the `go.Box` constructor. + + Returns + ------- + list of plotly.graph_objects.Box + + Notes + ----- + - The width of the box plots is scaled based on the x-axis spacing. + - Extra boxes are added for quantiles beyond the standard five-point box plot. + - Hover information is included for quantile scatter points, showing the + quantile value and percentage. + """ + traces = box.box(*args, **kwargs) + for i, trace in enumerate(traces): + if isinstance(trace, list): + if self._fig is None: + self._rows = self._rows or len(traces) + self._columns = self._columns or 1 + for sub_trace in trace: + if not isinstance(sub_trace, (list, tuple)): + sub_trace = [sub_trace] + for actual_trace in sub_trace: + self.add_trace(actual_trace, row=i + 1, col=1) + else: + self.add_trace(trace) + + @set_subplot_titles + def bar(self, *args, **kwargs): + """ + Adds a bar plot to the chart. + + Parameters + ---------- + data : array-like or earthkit.data.FieldList + The data to be plotted. + *args : tuple + Positional arguments passed to the bar plot generation function. + **kwargs : dict + Additional options for customizing the bar plot. + + Notes + ----- + Bar plots are added as individual traces to each subplot. + Titles are inferred from data attributes if not provided. + """ + traces = bar.bar(*args, **kwargs) + for i, trace in enumerate(traces): + if isinstance(trace, list): + if self._fig is None: + self._rows = self._rows or len(traces) + self._columns = self._columns or 1 + for sub_trace in trace: + self.add_trace(sub_trace, row=i + 1, col=1) + else: + self.add_trace(trace) + + def title(self, title): + """ + Set the overall chart title. + + Parameters + ---------- + title : str + The title to display at the top of the chart. + """ + self._layout_override["title"] = title + + def show(self, *args, **kwargs): + """ + Display the chart. + + Parameters + ---------- + *args : tuple + Additional arguments for `plotly.graph_objects.Figure.show`. + renderer : str, optional + The renderer to use for displaying the chart. The default is "browser". + For static plots, use "png". + **kwargs : dict + Additional options for rendering the chart. + + Returns + ------- + None + """ + layout = { + **DEFAULT_LAYOUT, + **self._layout_override, + } + # Temporary fix to remove _parent keys from nested dictionaries + for k in layout: + if isinstance(layout[k], dict): + layout[k] = { + k2: v for k2, v in layout[k].items() if not k2.startswith("_") + } + self.fig.update_layout(**layout) + for i in range(self.rows * self.columns): + y_key = f"yaxis{i+1 if i>0 else ''}" + x_key = f"xaxis{i+1 if i>0 else ''}" + if self._subplot_x_titles: + self.fig.update_layout( + **{ + y_key: layout["yaxis"], + x_key: { + **layout["xaxis"], + **{"title": self._subplot_x_titles[i]}, + }, + } + ) + if self._subplot_y_titles: + self.fig.update_layout( + **{ + x_key: layout["xaxis"], + y_key: { + **layout["yaxis"], + **{"title": self._subplot_y_titles[i]}, + }, + } + ) + return self.fig.show(*args, **kwargs) diff --git a/src/earthkit/plots/interactive/inputs.py b/src/earthkit/plots/interactive/inputs.py new file mode 100644 index 0000000..37883a8 --- /dev/null +++ b/src/earthkit/plots/interactive/inputs.py @@ -0,0 +1,148 @@ +# Copyright 2023, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +import earthkit.data +import numpy as np + +from earthkit.plots.interactive import times + +# from earthkit.plots.schemas import schema + + +AXES = ["x", "y"] + + +def _earthkitify(data): + if isinstance(data, (list, tuple)): + data = np.array(data) + if not isinstance(data, earthkit.data.core.Base): + data = earthkit.data.from_object(data) + return data + + +def to_xarray(data): + return _earthkitify(data).to_xarray().squeeze() + + +def to_pandas(data): + try: + return _earthkitify(data).to_pandas() + except NotImplementedError: + return _earthkitify(data).to_xarray().squeeze().to_pandas() + + +def to_numpy(data): + return _earthkitify(data).to_numpy() + + +def sanitise(axes=("x", "y"), multiplot=True): + def decorator(function): + def wrapper( + data=None, + *args, + time_frequency=None, + time_aggregation="mean", + aggregation=None, + deaccumulate=False, + **kwargs, + ): + time_axis = kwargs.pop("time_axis", 0) + traces = [] + if data is not None: + ds = to_xarray(data) + time_dim = times.guess_time_dim(ds) + data_vars = list(ds.data_vars) + if time_frequency is not None: + if isinstance(time_aggregation, (list, tuple)): + for i, var_name in enumerate(data_vars): + ds[var_name] = getattr( + ds[var_name].resample(**{time_dim: time_frequency}), + time_aggregation[i], + )() + else: + ds = getattr( + ds.resample(**{time_dim: time_frequency}), time_aggregation + )() + time_axis = 1 + if aggregation is not None: + ds = getattr(ds, aggregation)(dim=times.guess_non_time_dim(ds)) + if "name" not in kwargs: + kwargs["name"] = aggregation + if deaccumulate: + if isinstance(deaccumulate, str): + ds[deaccumulate] = ds[deaccumulate].diff(dim=time_dim) + else: + ds = ds.diff(dim=time_dim) + if len(data_vars) > 1: + repeat_kwargs = { + k: v for k, v in kwargs.items() if k != "time_frequency" + } + repeat_kwargs + return [ + wrapper( + ds[data_var], *args, time_axis=time_axis, **repeat_kwargs + ) + for data_var in data_vars + ] + if len(ds.dims) == 2 and multiplot: + expand_dim = times.guess_non_time_dim(ds) + for i in range(len(ds[expand_dim])): + kwargs["name"] = f"{expand_dim}={ds[expand_dim][i].item()}" + trace_kwargs = get_xarray_kwargs( + ds.isel(**{expand_dim: i}), axes, kwargs + ) + traces.append(function(*args, **trace_kwargs)) + else: + trace_kwargs = get_xarray_kwargs(ds, axes, kwargs) + if not multiplot: + trace_kwargs["time_axis"] = time_axis + traces.append(function(*args, **trace_kwargs)) + else: + traces.append(function(*args, **kwargs)) + return traces + + return wrapper + + return decorator + + +def get_xarray_kwargs(data, axes, kwargs): + data = to_xarray(data) + kwargs = kwargs.copy() + data_vars = list(data.data_vars) + dim = list(data.dims)[-1] + + axis_attrs = dict() + assigned_attrs = [ + kwargs.get(axis).split(".")[-1] for axis in axes if axis in kwargs + ] + for axis in axes: + attr = kwargs.get(axis) + if attr is None: + if dim not in list(axis_attrs.values()) + assigned_attrs: + attr = dim + else: + attr = data_vars[0] + if len(data_vars) > 1: + warnings.warn( + f"dataset contains more than one data variable; " + f"variable '{attr}' has been selected for plotting" + ) + + kwargs[axis] = data[attr].values + axis_attrs[axis] = attr + + return kwargs diff --git a/src/earthkit/plots/interactive/line.py b/src/earthkit/plots/interactive/line.py new file mode 100644 index 0000000..3564576 --- /dev/null +++ b/src/earthkit/plots/interactive/line.py @@ -0,0 +1,24 @@ +# Copyright 2024, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import plotly.graph_objects as go + +from earthkit.plots.interactive import inputs + + +# @schema.line.apply() +@inputs.sanitise() +def line(*args, **kwargs): + trace = go.Scatter(*args, **kwargs) + return trace diff --git a/src/earthkit/plots/interactive/times.py b/src/earthkit/plots/interactive/times.py new file mode 100644 index 0000000..2928052 --- /dev/null +++ b/src/earthkit/plots/interactive/times.py @@ -0,0 +1,40 @@ +# Copyright 2024, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import warnings + +TIME_DIMS = ["time", "t", "month"] + + +def guess_time_dim(data): + with warnings.catch_warnings(): + warnings.simplefilter("ignore") + dims = dict(data.squeeze().dims) + for dim in TIME_DIMS: + if dim in dims: + return dim + + +def guess_non_time_dim(data): + dims = list(data.squeeze().dims) + for dim in TIME_DIMS: + if dim in dims: + dims.pop(dims.index(dim)) + break + + if len(dims) == 1: + return list(dims)[0] + + else: + raise ValueError("could not identify single dim over which to aggregate") diff --git a/src/earthkit/plots/interactive/utils.py b/src/earthkit/plots/interactive/utils.py new file mode 100644 index 0000000..a80ec2b --- /dev/null +++ b/src/earthkit/plots/interactive/utils.py @@ -0,0 +1,59 @@ +# Copyright 2023, European Centre for Medium Range Weather Forecasts. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import collections + + +def recursive_dict_update(original_dict, update_dict): + """ + Recursively update a dictionary with keys and values from another dictionary. + + Parameters + ---------- + original_dict : dict + The original dictionary to be updated (in place). + update_dict : dict + The dictionary containing keys to be updated in the original dictionary. + """ + for k, v in update_dict.items(): + if isinstance(v, collections.abc.Mapping): + original_dict[k] = recursive_dict_update(original_dict.get(k, {}), v) + else: + original_dict[k] = v + return original_dict + + +def list_to_human(iterable, conjunction="and", oxford_comma=False): + """ + Convert an iterable to a human-readable string. + + Parameters + ---------- + iterable : list or tuple + The list of strings to convert to a single natural language string. + conjunction : str, optional + The conjunction with which to join the last two elements of the list, + for example "and" (default). + oxford_comma : bool, optional + If `True`, an "Oxford comma" will be added before the conjunction when + there are three or more elements in the list. Default is `False`. + """ + list_of_strs = [str(item) for item in iterable] + + if len(list_of_strs) > 2: + list_of_strs = [", ".join(list_of_strs[:-1]), list_of_strs[-1]] + if oxford_comma: + list_of_strs[0] += "," + + return f" {conjunction} ".join(list_of_strs) diff --git a/src/earthkit/plots/metadata/units.py b/src/earthkit/plots/metadata/units.py index f608eae..f94fdae 100644 --- a/src/earthkit/plots/metadata/units.py +++ b/src/earthkit/plots/metadata/units.py @@ -116,7 +116,10 @@ def format_units(units, exponential_notation=False): >>> format_units("kg m-2") "$kg m^{-2}$" """ - latex_str = f"{_pintify(units):~L}" + units = _pintify(units) + if units.dimensionless: + return "dimensionless" + latex_str = f"{units:~L}" if exponential_notation: raise NotImplementedError("Exponential notation is not yet supported.") return f"${latex_str}$" diff --git a/src/earthkit/plots/sources/earthkit.py b/src/earthkit/plots/sources/earthkit.py index 913f11b..e63916e 100644 --- a/src/earthkit/plots/sources/earthkit.py +++ b/src/earthkit/plots/sources/earthkit.py @@ -133,9 +133,17 @@ def extract_xy(self): ) points = get_points(1) else: - points = self.data.to_points(flatten=False) - x = points["x"] - y = points["y"] + try: + points = self.data.to_points(flatten=False) + x = points["x"] + y = points["y"] + except ValueError: + latlon = self.data.to_latlon(flatten=False) + lat = latlon["lat"] + lon = latlon["lon"] + transformed = self.crs.transform_points(ccrs.PlateCarree(), lon, lat) + x = transformed[:, :, 0] + y = transformed[:, :, 1] return x, y def extract_x(self): diff --git a/src/earthkit/plots/sources/xarray.py b/src/earthkit/plots/sources/xarray.py index 30e364b..290b2ef 100644 --- a/src/earthkit/plots/sources/xarray.py +++ b/src/earthkit/plots/sources/xarray.py @@ -14,6 +14,7 @@ from functools import cached_property +import numpy as np import pandas as pd from earthkit.plots import identifiers @@ -82,7 +83,10 @@ def metadata(self, key, default=None): def datetime(self): """Get the datetime of the data.""" - datetimes = [pd.to_datetime(dt).to_pydatetime() for dt in self.data.time.values] + datetimes = [ + pd.to_datetime(dt).to_pydatetime() + for dt in np.atleast_1d(self.data.time.values) + ] return { "base_time": datetimes, "valid_time": datetimes, @@ -175,10 +179,7 @@ def crs(self): def x_values(self): """The x values of the data.""" super().x_values - x = self.data[self._x].values - if self.extract_x() in identifiers.LONGITUDE and (max(abs(x)) > 180): - x -= 180 - return x + return self.data[self._x].values @cached_property def y_values(self): diff --git a/src/earthkit/plots/styles/legends.py b/src/earthkit/plots/styles/legends.py index e5a0106..19bef4a 100644 --- a/src/earthkit/plots/styles/legends.py +++ b/src/earthkit/plots/styles/legends.py @@ -13,7 +13,7 @@ # limitations under the License. -DEFAULT_LEGEND_LABEL = "" # {variable_name} ({units})" +DEFAULT_LEGEND_LABEL = "{variable_name} ({units})" _DISJOINT_LEGEND_LOCATIONS = { "bottom": { @@ -55,7 +55,10 @@ def colorbar(layer, *args, shrink=0.8, aspect=35, ax=None, **kwargs): Any keyword arguments accepted by `matplotlib.figures.Figure.colorbar`. """ label = kwargs.pop("label", DEFAULT_LEGEND_LABEL) - label = layer.format_string(label) + try: + label = layer.format_string(label) + except (AttributeError, ValueError, KeyError): + label = "" kwargs = {**layer.style._legend_kwargs, **kwargs} kwargs.setdefault("format", lambda x, _: f"{x:g}") diff --git a/src/earthkit/plots/styles/levels.py b/src/earthkit/plots/styles/levels.py index 5527ddc..730dd2b 100644 --- a/src/earthkit/plots/styles/levels.py +++ b/src/earthkit/plots/styles/levels.py @@ -12,8 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. -import math - import numpy as np from earthkit.plots.schemas import schema @@ -50,28 +48,28 @@ def auto_range(data, divergence_point=None, n_levels=schema.default_style_levels if divergence_point is not None: max_diff = max(max_value - divergence_point, divergence_point - min_value) - max_value = divergence_point + max_diff - min_value = divergence_point - max_diff + max_value = max_diff + min_value = -max_diff data_range = max_value - min_value initial_bin = data_range / n_levels - magnitude = 10 ** (math.floor(math.log(initial_bin, 10))) + magnitude = 10 ** (np.floor(np.log10(initial_bin))) bin_width = initial_bin - (initial_bin % -magnitude) - start = min_value - (min_value % magnitude) - - levels = np.arange( - start, - start + (bin_width * n_levels) + bin_width, - bin_width, - ).tolist() + min_value -= min_value % bin_width + max_value -= max_value % -bin_width - while levels[-2] >= max_value: - levels = levels[:-1] + if divergence_point is not None: + min_value += divergence_point + max_value += divergence_point - return levels + return np.linspace( + min_value, + max_value, + n_levels + 1, + ).tolist() def step_range(data, step, reference=None): diff --git a/tests/interactive/test_Chart.py b/tests/interactive/test_Chart.py new file mode 100644 index 0000000..3f3cf69 --- /dev/null +++ b/tests/interactive/test_Chart.py @@ -0,0 +1,48 @@ +from plotly.graph_objects import Figure + +from earthkit.plots.interactive import ( + Chart, # Replace 'module' with the module containing the Chart class. +) + + +def test_chart_initialization(): + """Test initialization of the Chart class with default values.""" + chart = Chart() + assert chart.rows == 1 + assert chart.columns == 1 + assert chart._fig is None + assert chart._layout_override == {} + assert chart._subplots_kwargs == {} + + +def test_chart_initialization_with_args(): + """Test initialization of the Chart class with custom rows and columns.""" + chart = Chart(rows=3, columns=2) + assert chart.rows == 3 + assert chart.columns == 2 + + +def test_chart_fig_creation(): + """Test the creation of the figure property.""" + chart = Chart(rows=2, columns=3) + fig = chart.fig + assert isinstance(fig, Figure) + assert len(fig.layout.annotations) == 0 # No subplot titles initially. + + +def test_chart_add_trace(): + """Test adding a trace to the chart.""" + chart = Chart(rows=1, columns=1) + trace_data = {"x": [1, 2, 3], "y": [4, 5, 6], "type": "scatter"} + chart.add_trace(trace_data) + assert len(chart.fig.data) == 1 + assert chart.fig.data[0].type == "scatter" + assert chart.fig.data[0].x == (1, 2, 3) + assert chart.fig.data[0].y == (4, 5, 6) + + +def test_chart_title(): + """Test setting the chart title.""" + chart = Chart(rows=1, columns=1) + chart.title("Test Chart Title") + assert chart._layout_override["title"] == "Test Chart Title" diff --git a/tests/interactive/test_inputs.py b/tests/interactive/test_inputs.py new file mode 100644 index 0000000..8c649a2 --- /dev/null +++ b/tests/interactive/test_inputs.py @@ -0,0 +1,62 @@ +import numpy as np +import pandas as pd +import pytest +import xarray as xr + +from earthkit.plots.interactive.inputs import _earthkitify, to_numpy, to_xarray + + +class MockEarthkitData: + def __init__(self, data): + self.data = data + + def to_xarray(self): + return xr.DataArray(self.data) + + def to_pandas(self): + return pd.Series(self.data) + + def to_numpy(self): + return np.array(self.data) + + +@pytest.fixture +def mock_earthkit(monkeypatch): + """Mock the earthkit.data.from_object function to return mock data.""" + + def mock_from_object(data): + return MockEarthkitData(data) + + monkeypatch.setattr("earthkit.data.from_object", mock_from_object) + + +def test_earthkitify_with_list(mock_earthkit): + """Test _earthkitify with a list input.""" + data = [1, 2, 3] + result = _earthkitify(data) + assert isinstance(result, MockEarthkitData) + assert np.array_equal(result.to_numpy(), np.array(data)) + + +def test_earthkitify_with_numpy(mock_earthkit): + """Test _earthkitify with a numpy array.""" + data = np.array([1, 2, 3]) + result = _earthkitify(data) + assert isinstance(result, MockEarthkitData) + assert np.array_equal(result.to_numpy(), data) + + +def test_to_xarray(mock_earthkit): + """Test to_xarray conversion.""" + data = [1, 2, 3] + result = to_xarray(data) + assert isinstance(result, xr.DataArray) + assert np.array_equal(result.values, np.array(data)) + + +def test_to_numpy(mock_earthkit): + """Test to_numpy conversion.""" + data = [1, 2, 3] + result = to_numpy(data) + assert isinstance(result, np.ndarray) + assert np.array_equal(result, np.array(data))