diff --git a/CHANGELOG.md b/CHANGELOG.md
index 21ef6ff3..287d76ea 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
 - Save entire config in mlflow
 ### Added
 - Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
+- Include option to use datashader and optimised asynchronous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102)
 - Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
 - Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137)
 - Add without subsetting in ScaleTensor
diff --git a/docs/modules/diagnostics.rst b/docs/modules/diagnostics.rst
index 28eac7c7..4364e683 100644
--- a/docs/modules/diagnostics.rst
+++ b/docs/modules/diagnostics.rst
@@ -51,12 +51,32 @@ parameters to plot, as well as the plotting frequency, and asynchronosity.
 
 Setting ``config.diagnostics.plot.asynchronous``, means that the model
-training doesn't stop whilst the callbacks are being evaluated)
+training doesn't stop whilst the callbacks are being evaluated. This is
+useful for large models where the plotting can take a long time. The
+plotting module uses asynchronous callbacks via `asyncio` and
+`concurrent.futures.ThreadPoolExecutor` to handle plotting tasks without
+blocking the main application. A dedicated event loop runs in a separate
+background thread, and plotting tasks are offloaded from it to worker
+threads. This setup keeps the main thread responsive while plot-related
+tasks are handled efficiently in the background.
+
+There is an additional flag in the plotting callbacks to control the
+rendering method for geospatial plots, offering a trade-off between
+performance and detail. When `datashader` is set to True, Datashader is
+used for rendering: it accelerates plotting by spatially binning and
+aggregating the data points, which is particularly useful for large
+datasets and can also produce smoother-looking plots. If `datashader` is
+set to False, `matplotlib.pyplot.scatter` is used, which provides
+sharper and more detailed visuals but may be slower for large datasets.
+
+**Note** - this asynchronous behaviour is only available for the
+plotting callbacks.
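+
+As a minimal, self-contained sketch of the asynchronous pattern described
+above (illustrative only, not the actual callback code), the event loop
+lives in a daemon thread and plot calls are submitted to it from the
+main thread:
+
+.. code:: python
+
+   import asyncio
+   import threading
+   from concurrent.futures import ThreadPoolExecutor
+
+   executor = ThreadPoolExecutor(max_workers=1)
+   loop = asyncio.new_event_loop()
+   threading.Thread(target=loop.run_forever, daemon=True).start()
+
+   async def submit(fn, *args):
+       # Offload the blocking plot call to the worker thread.
+       await asyncio.get_running_loop().run_in_executor(executor, fn, *args)
+
+   # Called from the main thread: returns immediately.
+   asyncio.run_coroutine_threadsafe(submit(print, "plotting ..."), loop)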
 
 .. code:: yaml
 
     plot:
      asynchronous: True # Whether to plot asynchronously
+     datashader: True # Whether to use datashader for plotting (faster)
      frequency: # Frequency of the plotting
        batch: 750
        epoch: 5
diff --git a/pyproject.toml b/pyproject.toml
index f3e730d3..8d685acd 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
   "anemoi-graphs>=0.4",
   "anemoi-models>=0.3",
   "anemoi-utils[provenance]>=0.4.4",
+  "datashader>=0.16.3",
   "einops>=0.6.1",
   "hydra-core>=1.3",
   "matplotlib>=3.7.1",
diff --git a/src/anemoi/training/config/diagnostics/plot/detailed.yaml b/src/anemoi/training/config/diagnostics/plot/detailed.yaml
index b759c17b..d1ac8b0f 100644
--- a/src/anemoi/training/config/diagnostics/plot/detailed.yaml
+++ b/src/anemoi/training/config/diagnostics/plot/detailed.yaml
@@ -1,4 +1,5 @@
 asynchronous: True # Whether to plot asynchronously
+datashader: True # Whether to use datashader for plotting (faster)
 frequency: # Frequency of the plotting
   batch: 750
   epoch: 5
diff --git a/src/anemoi/training/config/diagnostics/plot/simple.yaml b/src/anemoi/training/config/diagnostics/plot/simple.yaml
index 2a987ccb..63c805a2 100644
--- a/src/anemoi/training/config/diagnostics/plot/simple.yaml
+++ b/src/anemoi/training/config/diagnostics/plot/simple.yaml
@@ -1,4 +1,5 @@
 asynchronous: True # Whether to plot asynchronously
+datashader: True # Whether to use datashader for plotting (faster)
 frequency: # Frequency of the plotting
   batch: 750
   epoch: 10
diff --git a/src/anemoi/training/data/datamodule.py b/src/anemoi/training/data/datamodule.py
index 0d3d1b3f..303266fc 100644
--- a/src/anemoi/training/data/datamodule.py
+++ b/src/anemoi/training/data/datamodule.py
@@ -140,8 +140,7 @@ def ds_train(self) -> NativeGridDataset:
 
     @cached_property
     def ds_valid(self) -> NativeGridDataset:
-        r = self.rollout
-        r = max(r, self.config.dataloader.get("validation_rollout", 1))
+        r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))
 
         assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
             f"Training end date {self.config.dataloader.training.end} is not before"
diff --git a/src/anemoi/training/diagnostics/callbacks/plot.py b/src/anemoi/training/diagnostics/callbacks/plot.py
index 869a69fb..171eb840 100644
--- a/src/anemoi/training/diagnostics/callbacks/plot.py
+++ b/src/anemoi/training/diagnostics/callbacks/plot.py
@@ -7,13 +7,13 @@
 # granted to it by virtue of its status as an intergovernmental organisation
 # nor does it submit to any jurisdiction.
 
-# ruff: noqa: ANN001
 from __future__ import annotations
 
+import asyncio
 import copy
 import logging
-import sys
+import threading
 import time
 import traceback
 from abc import ABC
@@ -23,8 +23,6 @@
 from functools import cached_property
 from pathlib import Path
 from typing import TYPE_CHECKING
-from typing import Any
-from typing import Callable
 
 import matplotlib.patches as mpatches
 import matplotlib.pyplot as plt
@@ -43,33 +41,14 @@
 from anemoi.training.losses.weightedloss import BaseWeightedLoss
 
 if TYPE_CHECKING:
+    from typing import Any
+
     import pytorch_lightning as pl
     from omegaconf import OmegaConf
 
 LOGGER = logging.getLogger(__name__)
 
 
-class ParallelExecutor(ThreadPoolExecutor):
-    """Wraps parallel execution and provides accurate information about errors.
-
-    Extends ThreadPoolExecutor to preserve the original traceback and line number.
-
-    Reference: https://stackoverflow.com/questions/19309514/getting-original-line-
-    number-for-exception-in-concurrent-futures/24457608#24457608
-    """
-
-    def submit(self, fn: Any, *args, **kwargs) -> Callable:
-        """Submits the wrapped function instead of `fn`."""
-        return super().submit(self._function_wrapper, fn, *args, **kwargs)
-
-    def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable:
-        """Wraps `fn` in order to preserve the traceback of any kind of."""
-        try:
-            return fn(*args, **kwargs)
-        except Exception as exc:
-            raise sys.exc_info()[0](traceback.format_exc()) from exc
-
-
 class BasePlotCallback(Callback, ABC):
     """Factory for creating a callback that plots data to Experiment Logging."""
 
@@ -93,11 +72,21 @@ def __init__(self, config: OmegaConf) -> None:
         self.plot = self._plot
 
         self._executor = None
+        self._error: BaseException | None = None
+        self.datashader_plotting = config.diagnostics.plot.datashader
 
         if self.config.diagnostics.plot.asynchronous:
-            self._executor = ParallelExecutor(max_workers=1)
-            self._error: BaseException | None = None
+            LOGGER.info("Setting up asynchronous plotting ...")
             self.plot = self._async_plot
+            self._executor = ThreadPoolExecutor(max_workers=1)
+            self.loop_thread = threading.Thread(target=self.start_event_loop, daemon=True)
+            self.loop_thread.start()
+
+    def start_event_loop(self) -> None:
+        """Start the event loop in a separate thread."""
+        self.loop = asyncio.new_event_loop()
+        asyncio.set_event_loop(self.loop)
+        self.loop.run_forever()
 
     @rank_zero_only
     def _output_figure(
@@ -113,27 +102,48 @@ def _output_figure(
         save_path = Path(
             self.save_basedir,
             "plots",
-            f"{tag}_epoch{epoch:03d}.png",
+            f"{tag}_epoch{epoch:03d}.jpg",
         )
 
         save_path.parent.mkdir(parents=True, exist_ok=True)
-        fig.savefig(save_path, dpi=100, bbox_inches="tight")
+        fig.canvas.draw()
+        image_array = np.array(fig.canvas.renderer.buffer_rgba())
+        plt.imsave(save_path, image_array, dpi=100)
         if self.config.diagnostics.log.wandb.enabled:
             import wandb
 
             logger.experiment.log({exp_log_tag: wandb.Image(fig)})
-
         if self.config.diagnostics.log.mlflow.enabled:
             run_id = logger.run_id
             logger.experiment.log_artifact(run_id, str(save_path))
 
         plt.close(fig)  # cleanup
 
+    @rank_zero_only
+    def _plot_with_error_catching(self, trainer: pl.Trainer, args: Any, kwargs: Any) -> None:
+        """Execute the plot function, catching any errors raised in the worker thread."""
+        try:
+            self._plot(trainer, *args, **kwargs)
+        except BaseException:
+            import os
+
+            LOGGER.exception("Critical error occurred in asynchronous plots.")
+            os._exit(1)  # to force exit when sanity val steps are used
+
     def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
-        """Method is called to close the threads."""
+        """Teardown the callback."""
         del trainer, pl_module, stage  # unused
+        LOGGER.info("Teardown of the Plot Callback ...")
+
         if self._executor is not None:
-            self._executor.shutdown(wait=True)
+            LOGGER.info("Shutting down the executor and cancelling pending plots ...")
+            self._executor.shutdown(wait=False, cancel_futures=True)
+
+            self.loop.call_soon_threadsafe(self.loop.stop)
+            self.loop_thread.join()
+            # Close the asyncio event loop now that its thread has exited.
+            self.loop.close()
 
     def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor:
         if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None:
@@ -147,31 +157,39 @@ def _plot(
         self,
         trainer: pl.Trainer,
         pl_module: pl.LightningModule,
-        *args,
-        **kwargs,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
         """Plotting function to be implemented by subclasses."""
 
+    # Async function to run the plot function in the background thread
+    async def submit_plot(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None:
+        """Coroutine that schedules the plot function on the background event loop."""
+        loop = asyncio.get_running_loop()
+        # loop.run_in_executor only accepts positional arguments, so the
+        # args/kwargs of the plot call are forwarded as two plain objects.
+        await loop.run_in_executor(
+            self._executor,
+            self._plot_with_error_catching,
+            trainer,
+            args,
+            kwargs,
+        )
+
     @rank_zero_only
     def _async_plot(
         self,
         trainer: pl.Trainer,
-        *args: list,
-        **kwargs: dict,
+        *args: Any,
+        **kwargs: Any,
     ) -> None:
-        """To execute the plot function but ensuring we catch any errors."""
-        future = self._executor.submit(
-            self._plot,
-            trainer,
-            *args,
-            **kwargs,
-        )
-        # otherwise the error won't be thrown till the validation epoch is finished
-        try:
-            future.result()
-        except Exception:
-            LOGGER.exception("Critical error occurred in asynchronous plots.")
-            sys.exit(1)
+        """Run the plot function asynchronously.
+
+        This is the function called by the callback: it schedules the plot
+        function onto the event loop running in the background thread.
+        """
+        asyncio.run_coroutine_threadsafe(self.submit_plot(trainer, *args, **kwargs), self.loop)
 
 
 class BasePerBatchPlotCallback(BasePlotCallback):
@@ -192,26 +210,12 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None):
         super().__init__(config)
         self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch
 
-    @abstractmethod
     @rank_zero_only
-    def _plot(
+    def on_validation_batch_end(
         self,
         trainer: pl.Trainer,
         pl_module: pl.LightningModule,
-        outputs: list[torch.Tensor],
-        batch: torch.Tensor,
-        batch_idx: int,
-        epoch: int,
-        **kwargs,
-    ) -> None:
-        """Plotting function to be implemented by subclasses."""
-
-    @rank_zero_only
-    def on_validation_batch_end(
-        self,
-        trainer,
-        pl_module,
-        output,
+        output: list[torch.Tensor],
         batch: torch.Tensor,
         batch_idx: int,
         **kwargs,
@@ -310,12 +314,12 @@ def __init__(
     @rank_zero_only
     def _plot(
         self,
-        trainer,
+        trainer: pl.Trainer,
         pl_module: pl.LightningModule,
         output: list[torch.Tensor],
         batch: torch.Tensor,
-        batch_idx,
-        epoch,
+        batch_idx: int,
+        epoch: int,
     ) -> None:
         _ = output
@@ -406,9 +410,9 @@ def _plot(
     @rank_zero_only
     def on_validation_batch_end(
         self,
-        trainer,
-        pl_module,
-        output,
+        trainer: pl.Trainer,
+        pl_module: pl.LightningModule,
+        output: list[torch.Tensor],
         batch: torch.Tensor,
         batch_idx: int,
     ) -> None:
@@ -454,7 +458,7 @@ def _plot(
         _ = epoch
         model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model
 
-        fig = plot_graph_node_features(model)
+        fig = plot_graph_node_features(model, datashader=self.datashader_plotting)
 
         self._output_figure(
             trainer.logger,
@@ -750,6 +754,7 @@ def _plot(
                 data[0, ...].squeeze(),
                 data[rollout_step + 1, ...].squeeze(),
                 output_tensor[rollout_step, ...],
+                datashader=self.datashader_plotting,
                 precip_and_related_fields=self.precip_and_related_fields,
             )
@@ -839,7 +844,7 @@ def _plot(
         self,
         trainer: pl.Trainer,
         pl_module: pl.LightningModule,
-        outputs: list,
+        outputs: list[torch.Tensor],
         batch: torch.Tensor,
         batch_idx: int,
         epoch: int,
@@ -921,7 +926,7 @@ def _plot(
         self,
         trainer: pl.Trainer,
         pl_module: pl.LightningModule,
-        outputs: list,
+        outputs: list[torch.Tensor],
         batch: torch.Tensor,
         batch_idx: int,
         epoch: int,
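A note on the error handling in `_plot_with_error_catching`:
`asyncio.run_coroutine_threadsafe` returns a `concurrent.futures.Future`
that the callback never queries, so an exception raised inside `_plot`
would otherwise vanish silently in the background thread. A minimal,
self-contained sketch of that behaviour (illustrative only, not part of
the patch):

    import asyncio
    import threading

    loop = asyncio.new_event_loop()
    threading.Thread(target=loop.run_forever, daemon=True).start()

    async def boom() -> None:
        raise RuntimeError("plotting failed")

    # The exception does not surface in the calling thread ...
    future = asyncio.run_coroutine_threadsafe(boom(), loop)
    # ... unless the returned Future is explicitly queried:
    try:
        future.result(timeout=5)  # re-raises the RuntimeError here
    except RuntimeError as exc:
        print(f"caught: {exc}")

Since nothing queries the Future here, the callback instead logs the
traceback inside the worker and force-exits the process.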
diff --git a/src/anemoi/training/diagnostics/maps.py b/src/anemoi/training/diagnostics/maps.py
index 338a9059..fcf88921 100644
--- a/src/anemoi/training/diagnostics/maps.py
+++ b/src/anemoi/training/diagnostics/maps.py
@@ -32,7 +32,7 @@ def __init__(self) -> None:
     def __call__(self, lon: np.ndarray, lat: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
         lon_rad = np.radians(lon)
         lat_rad = np.radians(lat)
-        x = [v - 2 * np.pi if v > np.pi else v for v in lon_rad]
+        x = np.where(lon_rad > np.pi, lon_rad - 2 * np.pi, lon_rad)  # wrap to [-pi, pi]; preserves dtype
         y = lat_rad
         return x, y
diff --git a/src/anemoi/training/diagnostics/plots.py b/src/anemoi/training/diagnostics/plots.py
index dde80018..d397f05c 100644
--- a/src/anemoi/training/diagnostics/plots.py
+++ b/src/anemoi/training/diagnostics/plots.py
@@ -13,10 +13,13 @@
 import logging
 from typing import TYPE_CHECKING
 
+import datashader as dsh
 import matplotlib.pyplot as plt
 import matplotlib.style as mplstyle
 import numpy as np
+import pandas as pd
 from anemoi.models.layers.mapper import GraphEdgeMixin
+from datashader.mpl_ext import dsshow
 from matplotlib.collections import LineCollection
 from matplotlib.colors import BoundaryNorm
 from matplotlib.colors import ListedColormap
@@ -37,6 +40,7 @@
 LOGGER = logging.getLogger(__name__)
 
 continents = Coastlines()
+LAYOUT = "tight"
 
 
 @dataclass
@@ -105,7 +109,7 @@ def plot_loss(
     # create plot
     # more space for legend
     figsize = (8, 3) if legend_patches else (4, 3)
-    fig, ax = plt.subplots(1, 1, figsize=figsize)
+    fig, ax = plt.subplots(1, 1, figsize=figsize, layout=LAYOUT)
     # histogram plot
     ax.bar(np.arange(x.size), x, color=colors, log=1)
 
@@ -114,8 +118,7 @@
     ax.set_xticks(list(xticks.values()), list(xticks.keys()), rotation=60)
     if legend_patches:
         # legend outside and to the right of the plot
-        plt.legend(handles=legend_patches, bbox_to_anchor=(1.01, 1), loc="upper left")
-        plt.tight_layout()
+        ax.legend(handles=legend_patches, bbox_to_anchor=(1.01, 1), loc="upper left")
     return fig
 
@@ -154,7 +157,7 @@ def plot_power_spectrum(
     n_plots_x, n_plots_y = len(parameters), 1
 
     figsize = (n_plots_y * 4, n_plots_x * 3)
-    fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize)
+    fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT)
 
     pc = EquirectangularProjection()
     lat, lon = latlons[:, 0], latlons[:, 1]
@@ -217,7 +220,6 @@ def plot_power_spectrum(
         ax[plot_idx].set_xlabel("$k$")
         ax[plot_idx].set_ylabel("$P(k)$")
         ax[plot_idx].set_aspect("auto", adjustable=None)
-    fig.tight_layout()
     return fig
 
@@ -285,7 +287,7 @@ def plot_histogram(
     n_plots_x, n_plots_y = len(parameters), 1
 
     figsize = (n_plots_y * 4, n_plots_x * 3)
-    fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize)
+    fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT)
 
     for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
         yt = y_true[..., variable_idx].squeeze()
@@ -325,7 +327,6 @@ def plot_histogram(
             ax[plot_idx].legend()
             ax[plot_idx].set_aspect("auto", adjustable=None)
 
-    fig.tight_layout()
     return fig
 
@@ -338,6 +339,7 @@ def plot_predicted_multilevel_flat_sample(
     x: np.ndarray,
     y_true: np.ndarray,
     y_pred: np.ndarray,
+    datashader: bool = False,
     precip_and_related_fields: list | None = None,
 ) -> Figure:
     """Plots data for one multilevel latlon-"flat" sample.
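On the `LAYOUT = "tight"` constant introduced above: passing
`layout="tight"` at figure creation installs matplotlib's tight layout
engine, which is applied automatically on every draw, so the per-figure
`fig.tight_layout()` calls removed in this diff are no longer needed. A
small illustrative sketch (synthetic data, not part of the diff):

    import matplotlib.pyplot as plt

    # The tight layout engine runs at draw time; no explicit
    # fig.tight_layout() call is needed before saving.
    fig, ax = plt.subplots(2, 1, figsize=(4, 6), layout="tight")
    ax[0].bar(range(3), [1, 2, 3])
    ax[1].plot([0, 1], [1, 0])
    fig.savefig("layout_example.png", dpi=100)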
@@ -363,6 +365,8 @@ def plot_predicted_multilevel_flat_sample(
         Expected data of shape (lat*lon, nvar*level)
     y_pred : np.ndarray
         Predicted data of shape (lat*lon, nvar*level)
+    datashader: bool, optional
+        Render with Datashader rather than a scatter plot, by default False
     precip_and_related_fields : list, optional
         List of precipitation-like variables, by default []
 
@@ -375,7 +379,7 @@ def plot_predicted_multilevel_flat_sample(
     n_plots_x, n_plots_y = len(parameters), n_plots_per_sample
 
     figsize = (n_plots_y * 4, n_plots_x * 3)
-    fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize)
+    fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize, layout=LAYOUT)
 
     pc = EquirectangularProjection()
     lat, lon = latlons[:, 0], latlons[:, 1]
@@ -397,6 +401,7 @@ def plot_predicted_multilevel_flat_sample(
                 variable_name,
                 clevels,
                 cmap_precip,
+                datashader,
                 precip_and_related_fields,
             )
         else:
@@ -411,6 +416,7 @@ def plot_predicted_multilevel_flat_sample(
                 variable_name,
                 clevels,
                 cmap_precip,
+                datashader,
                 precip_and_related_fields,
             )
 
@@ -428,6 +434,7 @@ def plot_flat_sample(
     vname: str,
     clevels: float,
     cmap_precip: str,
+    datashader: bool = False,
     precip_and_related_fields: list | None = None,
 ) -> None:
     """Plot a "flat" 1D sample.
 
@@ -436,7 +443,7 @@
     Parameters
     ----------
-    fig : _type_
+    fig : Figure
         Figure object handle
     ax : matplotlib.axes
         Axis object handle
@@ -456,9 +463,14 @@ def plot_flat_sample(
         Accumulation levels used for precipitation related plots
     cmap_precip: str
         Colors used for each accumulation level
+    datashader: bool, optional
+        Render with Datashader rather than a scatter plot, by default False
     precip_and_related_fields : list, optional
         List of precipitation-like variables, by default []
 
+    Returns
+    -------
+    None
     """
     precip_and_related_fields = precip_and_related_fields or []
     if vname in precip_and_related_fields:
@@ -473,17 +485,38 @@ def plot_flat_sample(
         # converting to mm from m
         truth *= 1000.0
         pred *= 1000.0
-        scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, cmap=precip_colormap, norm=norm, title=f"{vname} target")
-        scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, cmap=precip_colormap, norm=norm, title=f"{vname} pred")
-        scatter_plot(
+        single_plot(
+            fig,
+            ax[1],
+            lon,
+            lat,
+            truth,
+            cmap=precip_colormap,
+            norm=norm,
+            title=f"{vname} target",
+            datashader=datashader,
+        )
+        single_plot(
+            fig,
+            ax[2],
+            lon,
+            lat,
+            pred,
+            cmap=precip_colormap,
+            norm=norm,
+            title=f"{vname} pred",
+            datashader=datashader,
+        )
+        single_plot(
             fig,
             ax[3],
-            lon=lon,
-            lat=lat,
-            data=truth - pred,
+            lon,
+            lat,
+            truth - pred,
             cmap="bwr",
             norm=TwoSlopeNorm(vcenter=0.0),
             title=f"{vname} pred err",
+            datashader=datashader,
         )
     elif vname == "mwd":
         cyclic_colormap = "twilight"
@@ -495,10 +528,28 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
         sample_shape = truth.shape
         pred = np.maximum(np.zeros(sample_shape), np.minimum(360 * np.ones(sample_shape), (pred)))
 
-        scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, cmap=cyclic_colormap, title=f"{vname} target")
-        scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, cmap=cyclic_colormap, title=f"capped {vname} pred")
+        single_plot(
+            fig,
+            ax[1],
+            lon=lon,
+            lat=lat,
+            data=truth,
+            cmap=cyclic_colormap,
+            title=f"{vname} target",
+            datashader=datashader,
+        )
+        single_plot(
+            fig,
+            ax[2],
+            lon=lon,
+            lat=lat,
+            data=pred,
+            cmap=cyclic_colormap,
+            title=f"capped {vname} pred",
+            datashader=datashader,
+        )
         err_plot = error_plot_in_degrees(truth, pred)
-        scatter_plot(
+        single_plot(
             fig,
             ax[3],
             lon=lon,
@@ -507,26 +558,37 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
             cmap="bwr",
             norm=TwoSlopeNorm(vcenter=0.0),
             title=f"{vname} pred err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.",
+            datashader=datashader,
         )
     else:
-        scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, title=f"{vname} target")
-        scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, title=f"{vname} pred")
-        scatter_plot(
+        single_plot(fig, ax[1], lon, lat, truth, title=f"{vname} target", datashader=datashader)
+        single_plot(fig, ax[2], lon, lat, pred, title=f"{vname} pred", datashader=datashader)
+        single_plot(
             fig,
             ax[3],
-            lon=lon,
-            lat=lat,
-            data=truth - pred,
+            lon,
+            lat,
+            truth - pred,
             cmap="bwr",
             norm=TwoSlopeNorm(vcenter=0.0),
             title=f"{vname} pred err",
+            datashader=datashader,
         )
 
     if sum(input_) != 0:
         if vname == "mwd":
-            scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, cmap=cyclic_colormap, title=f"{vname} input")
+            single_plot(
+                fig,
+                ax[0],
+                lon=lon,
+                lat=lat,
+                data=input_,
+                cmap=cyclic_colormap,
+                title=f"{vname} input",
+                datashader=datashader,
+            )
             err_plot = error_plot_in_degrees(pred, input_)
-            scatter_plot(
+            single_plot(
                 fig,
                 ax[4],
                 lon=lon,
@@ -535,9 +597,10 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
                 cmap="bwr",
                 norm=TwoSlopeNorm(vcenter=0.0),
                 title=f"{vname} increment [pred - input] % 360",
+                datashader=datashader,
             )
             err_plot = error_plot_in_degrees(truth, input_)
-            scatter_plot(
+            single_plot(
                 fig,
                 ax[5],
                 lon=lon,
@@ -546,28 +609,31 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
                 cmap="bwr",
                 norm=TwoSlopeNorm(vcenter=0.0),
                 title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.",
+                datashader=datashader,
             )
         else:
-            scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, title=f"{vname} input")
-            scatter_plot(
+            single_plot(fig, ax[0], lon, lat, input_, title=f"{vname} input", datashader=datashader)
+            single_plot(
                 fig,
                 ax[4],
-                lon=lon,
-                lat=lat,
-                data=pred - input_,
+                lon,
+                lat,
+                pred - input_,
                 cmap="bwr",
                 norm=TwoSlopeNorm(vcenter=0.0),
                 title=f"{vname} increment [pred - input]",
+                datashader=datashader,
             )
-            scatter_plot(
+            single_plot(
                 fig,
                 ax[5],
-                lon=lon,
-                lat=lat,
-                data=truth - input_,
+                lon,
+                lat,
+                truth - input_,
                 cmap="bwr",
                 norm=TwoSlopeNorm(vcenter=0.0),
                 title=f"{vname} persist err",
+                datashader=datashader,
             )
     else:
         ax[0].axis("off")
@@ -575,18 +641,21 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
         ax[5].axis("off")
 
 
-def scatter_plot(
+def single_plot(
     fig: Figure,
     ax: plt.Axes,
-    *,
     lon: np.array,
     lat: np.array,
     data: np.array,
     cmap: str = "viridis",
     norm: str | None = None,
     title: str | None = None,
+    datashader: bool = False,
 ) -> None:
-    """Lat-lon scatter plot: can work with arbitrary grids.
+    """Plot a single lat-lon map: works with arbitrary grids.
+
+    Rendering uses either `matplotlib.pyplot.scatter` (the default), which
+    keeps full per-point detail, or Datashader's binned aggregation, which
+    is faster and more memory-efficient for large datasets.
 
     Parameters
     ----------
@@ -598,7 +667,7 @@ def scatter_plot(
         longitude coordinates array, shape (lon,)
     lat : np.ndarray
         latitude coordinates array, shape (lat,)
-    data : _type_
+    data : np.ndarray
         Data to plot
     cmap : str, optional
         Colormap string from matplotlib, by default "viridis"
@@ -606,18 +675,42 @@ def scatter_plot(
         Normalization string from matplotlib, by default None
     title : str, optional
         Title for plot, by default None
+    datashader: bool, optional
+        Render with Datashader rather than a scatter plot, by default False
 
+    Returns
+    -------
+    None
     """
-    psc = ax.scatter(
-        lon,
-        lat,
-        c=data,
-        cmap=cmap,
-        s=1,
-        alpha=1.0,
-        norm=norm,
-        rasterized=True,
-    )
+    if not datashader:
+        psc = ax.scatter(
+            lon,
+            lat,
+            c=data,
+            cmap=cmap,
+            s=1,
+            alpha=1.0,
+            norm=norm,
+            rasterized=False,
+        )
+    else:
+        df = pd.DataFrame({"val": data, "x": lon, "y": lat})
+        # Heuristic canvas size: scale the number of pixels per axis with
+        # the number of points (the divisor 212 was chosen empirically).
+        n_pixels = int(np.floor(data.shape[0] / 212))
+        psc = dsshow(
+            df,
+            dsh.Point("x", "y"),
+            dsh.mean("val"),
+            vmin=data.min(),
+            vmax=data.max(),
+            cmap=cmap,
+            plot_width=n_pixels,
+            plot_height=n_pixels,
+            norm=norm,
+            aspect="auto",
+            ax=ax,
+        )
+
     ax.set_xlim((-np.pi, np.pi))
     ax.set_ylim((-np.pi / 2, np.pi / 2))
 
@@ -644,9 +737,9 @@ def edge_plot(
 
     Parameters
     ----------
-    fig : _type_
+    fig : Figure
         Figure object handle
-    ax : _type_
+    ax : matplotlib.axes
         Axis object handle
     src_coords : np.ndarray of shape (num_edges, 2)
         Source latitudes and longitudes.
@@ -680,13 +773,15 @@ def edge_plot(
     fig.colorbar(psc, ax=ax)
 
 
-def plot_graph_node_features(model: nn.Module) -> Figure:
+def plot_graph_node_features(model: nn.Module, datashader: bool = False) -> Figure:
     """Plot trainable graph node features.
 
     Parameters
     ----------
     model: AneomiModelEncProcDec
         Model object
+    datashader: bool, optional
+        Render with Datashader rather than a scatter plot, by default False
 
     Returns
     -------
@@ -696,7 +791,7 @@ def plot_graph_node_features(model: nn.Module) -> Figure:
     nrows = len(nodes_name := model._graph_data.node_types)
     ncols = min(model.node_attributes.trainable_tensors[m].trainable.shape[1] for m in nodes_name)
     figsize = (ncols * 4, nrows * 3)
-    fig, ax = plt.subplots(nrows, ncols, figsize=figsize)
+    fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT)
 
     for row, (mesh, trainable_tensor) in enumerate(model.node_attributes.trainable_tensors.items()):
         latlons = model.node_attributes.get_coordinates(mesh).cpu().numpy()
@@ -706,13 +801,14 @@ def plot_graph_node_features(model: nn.Module) -> Figure:
 
         for i in range(ncols):
             ax_ = ax[row, i] if ncols > 1 else ax[row]
-            scatter_plot(
+            single_plot(
                 fig,
                 ax_,
                 lon=lon,
                 lat=lat,
                 data=node_features[..., i],
                 title=f"{mesh} trainable feature #{i + 1}",
+                datashader=datashader,
            )
 
     return fig
@@ -744,7 +840,7 @@ def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) -> Figure:
     ncols = min(module.trainable.trainable.shape[1] for module in trainable_modules.values())
     nrows = len(trainable_modules)
     figsize = (ncols * 4, nrows * 3)
-    fig, ax = plt.subplots(nrows, ncols, figsize=figsize)
+    fig, ax = plt.subplots(nrows, ncols, figsize=figsize, layout=LAYOUT)
 
     for row, ((src, dst), graph_mapper) in enumerate(trainable_modules.items()):
         src_coords = model.node_attributes.get_coordinates(src).cpu().numpy()
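As a usage illustration of the new Datashader path, the following
minimal, self-contained sketch renders a synthetic field with the same
`dsshow` call that `single_plot` makes (all data here is synthetic; the
212 divisor copies the heuristic above):

    import datashader as dsh
    import matplotlib.pyplot as plt
    import numpy as np
    import pandas as pd
    from datashader.mpl_ext import dsshow

    rng = np.random.default_rng(0)
    n = 50_000
    lon = rng.uniform(-np.pi, np.pi, n)
    lat = rng.uniform(-np.pi / 2, np.pi / 2, n)
    val = np.sin(3 * lat) + 0.1 * rng.standard_normal(n)

    fig, ax = plt.subplots(figsize=(8, 4), layout="tight")
    df = pd.DataFrame({"val": val, "x": lon, "y": lat})
    n_pixels = int(np.floor(n / 212))  # heuristic canvas size, as in single_plot
    dsshow(
        df,
        dsh.Point("x", "y"),
        dsh.mean("val"),  # mean-aggregate points landing in the same pixel
        vmin=val.min(),
        vmax=val.max(),
        cmap="viridis",
        plot_width=n_pixels,
        plot_height=n_pixels,
        aspect="auto",
        ax=ax,
    )
    fig.savefig("datashader_example.jpg")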