Commit

Fix/async callbacks (ecmwf#102)
* Refactor Callbacks
- Split into separate files
- Use list in config to add callbacks
- Provide legacy config enabled approach
- Fix ruff issues

* Update changelog

* Fix TypeError

* Move to hydra.instantiate

* Add __all__

* Add to base config

* Fix nested list

* Fix nested get issue

* Fix type checking

* feat: edge plot in callbacks

* feat: set default extra callbacks

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* fix: typing & refactoring

* fix: remove list comprehension

* Refactor according to PR
- Prefill config with callbacks
- Warn on deprecations for old config
- Expand config enabled
- Add back SWA
- Fix logging callback
- Add flag to disable checkpointing
- Add testing

* Update deprecation warning

* Refactor: Remove backwards compatibility,
- Split plots
- Rename lr to optimiser
- Refactor plotting callbacks to be more init config

* Fix tests

* PR Fixes
- Remove enabled from plotting callbacks
- Connect sample_idx in config

* Update Changelog

* Refactor rollout (ecmwf#87)

Refactor rollout logic

* Remove batch frequency from LongRolloutPlots

* Remove TP reference

* Remove missing config reference

* Authentication support for mlflow sync (ecmwf#51)

* feat: authentication support for mlflow sync

* chore: formatting

* chore: changelog

* chore: changelog add link

* fix: sync authentication flag

* refactor: move `health_check` to submodule top level

* feat: add health check

* chore: update error msg

* refactor: mlflow utils

* New mlflow authentication API (ecmwf#78)

* fix: mlflow auth use web seed token

* feat: make target env var an optional argument

* chore: docstrings

* fix: tests

* chore: add comment

* chore: changelog

* chore: docstring

* Update changelog

* rebase

* Update deprecation warning

* Refactor: Remove backwards compatibility,
- Split plots
- Rename lr to optimiser
- Refactor plotting callbacks to be more init config

* add scatter plot

* adding async

* fix

* tests

* fix failing tests

* rm change to ds valid

* precommit hooks

* fix linting

* revert unnecessary config changes

* change config files

* Swapped histogram and spectrum

* Update copyright notice

* Fix issues with split of PlotAdditionalMetrics

* Fix CHANGELOG

* Fix documentation for callbacks

* Add all callback submodules to docs

* Apply suggestions from code review

Co-authored-by: Sara Hahner <[email protected]>

* Fix init args issue in RolloutPlots

* Add rollout_eval config

* Add training mode to rollout step

* Force LongRolloutPlots to plot in serial

* Add warning to LongRolloutPlots when async

* Fix assert calculation

* Apply post_processors before plotting in LongRolloutPlots

* Fix reference to batch

* Fix debug config

* bringing plot for mean wave direction and fixing type hinting

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* add changelog entry

* fixes for async plots to work

* fix pre-commit styling

* improved loop closing and readability

* fixing for pre-commit hooks

* remove commented block

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* address suggestion for args and kwargs and missing type hints

* update flag to datashader rather than scatter

* update configs

* update docs

* update comment for readability

* update branch

* update branch and test

---------

Co-authored-by: Harrison Cook <[email protected]>
Co-authored-by: Mario Santa Cruz <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Gert Mertes <[email protected]>
Co-authored-by: Sara Hahner <[email protected]>
Co-authored-by: anaprietonem <[email protected]>
Co-authored-by: Ana Prieto Nemesio <[email protected]>
8 people authored Nov 15, 2024
1 parent 76d3ef6 commit d0a8866
Showing 9 changed files with 259 additions and 135 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -21,6 +21,7 @@ Keep it human-readable, your future self will thank you!
- Save entire config in mlflow
### Added
- Included more loss functions and allowed configuration [#70](https://github.com/ecmwf/anemoi-training/pull/70)
- Include option to use datashader and optimised asynchronous callbacks [#102](https://github.com/ecmwf/anemoi-training/pull/102)
- Fix that applies the metric_ranges in the post-processed variable space [#116](https://github.com/ecmwf/anemoi-training/pull/116)
- Allow updates to scalars [#137](https://github.com/ecmwf/anemoi-training/pulls/137)
- Add without subsetting in ScaleTensor
22 changes: 21 additions & 1 deletion docs/modules/diagnostics.rst
@@ -51,12 +51,32 @@ parameters to plot, as well as the plotting frequency, and
asynchronicity.

Setting ``config.diagnostics.plot.asynchronous``, means that the model
training doesn't stop whilst the callbacks are being evaluated)
training doesn't stop whilst the callbacks are being evaluated. This is
useful for large models where the plotting can take a long time. The
plotting module uses asynchronous callbacks via `asyncio` and
`concurrent.futures.ThreadPoolExecutor` to handle plotting tasks without
blocking the main application. A dedicated event loop runs in a separate
background thread, allowing plotting tasks to be offloaded to worker
threads. This setup keeps the main thread responsive, handling
plot-related tasks asynchronously and efficiently in the background.

There is an additional flag in the plotting callbacks to control the
rendering method for geospatial plots, offering a trade-off between
performance and detail. When `datashader` is set to True, Datashader is
used for rendering, which accelerates plotting through efficient
hexbinning, particularly useful for large datasets. This approach can
produce smoother-looking plots due to the aggregation of data points. If
`datashader` is set to False, matplotlib.scatter is used, which provides
sharper and more detailed visuals but may be slower for large datasets.
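
As an illustration of this trade-off, the two rendering paths look roughly like
the sketch below. This is a minimal example assuming a pandas ``DataFrame`` of
points with one scalar field; it is not the code used by the callbacks
themselves.

.. code:: python

   import datashader as ds
   import datashader.transfer_functions as tf
   import matplotlib.pyplot as plt
   import numpy as np
   import pandas as pd

   # Illustrative data: one value per (lon, lat) point.
   n = 1_000_000
   df = pd.DataFrame(
       {
           "lon": np.random.uniform(-180, 180, n),
           "lat": np.random.uniform(-90, 90, n),
           "val": np.random.rand(n),
       }
   )

   # Datashader: aggregate all points onto a fixed-size canvas, then shade.
   canvas = ds.Canvas(plot_width=800, plot_height=400)
   agg = canvas.points(df, "lon", "lat", ds.mean("val"))
   image = tf.shade(agg, how="linear")  # rasterised image, cheap even for millions of points

   # Matplotlib: one marker per point, sharper but slower for large datasets.
   fig, ax = plt.subplots()
   ax.scatter(df["lon"], df["lat"], c=df["val"], s=1)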

**Note** - this asynchronous behaviour is only available for the
plotting callbacks.

.. code:: yaml
plot:
asynchronous: True # Whether to plot asynchronously
datashader: True # Whether to use datashader for plotting (faster)
frequency: # Frequency of the plotting
batch: 750
epoch: 5
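
For reference, the asynchronous mechanism described above can be reduced to a
small self-contained sketch: a private event loop runs in a daemon thread, the
training process hands plot jobs to it with
``asyncio.run_coroutine_threadsafe``, and the coroutine offloads the blocking
matplotlib work to a single-worker ``ThreadPoolExecutor``. The class and method
names below are illustrative only, not the actual callback API.

.. code:: python

   import asyncio
   import threading
   from concurrent.futures import ThreadPoolExecutor


   class BackgroundPlotter:
       """Minimal sketch of the background event-loop pattern."""

       def __init__(self) -> None:
           self._executor = ThreadPoolExecutor(max_workers=1)
           self._loop = asyncio.new_event_loop()
           self._thread = threading.Thread(target=self._run_loop, daemon=True)
           self._thread.start()

       def _run_loop(self) -> None:
           # Runs in the background thread for the lifetime of the callback.
           asyncio.set_event_loop(self._loop)
           self._loop.run_forever()

       async def _submit(self, fn, *args) -> None:
           # Offload the blocking plotting call to the worker thread.
           await asyncio.get_running_loop().run_in_executor(self._executor, fn, *args)

       def plot(self, fn, *args) -> None:
           # Called from the training thread; returns immediately.
           asyncio.run_coroutine_threadsafe(self._submit(fn, *args), self._loop)

       def close(self) -> None:
           self._executor.shutdown(wait=False, cancel_futures=True)
           self._loop.call_soon_threadsafe(self._loop.stop)
           self._thread.join()

In the actual callbacks (see the changes to ``plot.py`` below), the submitted
function also catches exceptions so that a failing plot aborts the run
immediately rather than surfacing only at the end of the validation epoch.
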
1 change: 1 addition & 0 deletions pyproject.toml
@@ -44,6 +44,7 @@ dependencies = [
"anemoi-graphs>=0.4",
"anemoi-models>=0.3",
"anemoi-utils[provenance]>=0.4.4",
"datashader>=0.16.3",
"einops>=0.6.1",
"hydra-core>=1.3",
"matplotlib>=3.7.1",
1 change: 1 addition & 0 deletions src/anemoi/training/config/diagnostics/plot/detailed.yaml
@@ -1,4 +1,5 @@
asynchronous: True # Whether to plot asynchronously
datashader: True # Choose which technique to use for plotting
frequency: # Frequency of the plotting
batch: 750
epoch: 5
1 change: 1 addition & 0 deletions src/anemoi/training/config/diagnostics/plot/simple.yaml
@@ -1,4 +1,5 @@
asynchronous: True # Whether to plot asynchronously
datashader: True # Choose which technique to use for plotting
frequency: # Frequency of the plotting
batch: 750
epoch: 10
3 changes: 1 addition & 2 deletions src/anemoi/training/data/datamodule.py
@@ -140,8 +140,7 @@ def ds_train(self) -> NativeGridDataset:

@cached_property
def ds_valid(self) -> NativeGridDataset:
r = self.rollout
r = max(r, self.config.dataloader.get("validation_rollout", 1))
r = max(self.rollout, self.config.dataloader.get("validation_rollout", 1))

assert self.config.dataloader.training.end < self.config.dataloader.validation.start, (
f"Training end date {self.config.dataloader.training.end} is not before"
153 changes: 79 additions & 74 deletions src/anemoi/training/diagnostics/callbacks/plot.py
@@ -7,13 +7,13 @@
# granted to it by virtue of its status as an intergovernmental organisation
# nor does it submit to any jurisdiction.

# ruff: noqa: ANN001

from __future__ import annotations

import asyncio
import copy
import logging
import sys
import threading
import time
import traceback
from abc import ABC
@@ -23,8 +23,6 @@
from functools import cached_property
from pathlib import Path
from typing import TYPE_CHECKING
from typing import Any
from typing import Callable

import matplotlib.patches as mpatches
import matplotlib.pyplot as plt
@@ -43,33 +41,14 @@
from anemoi.training.losses.weightedloss import BaseWeightedLoss

if TYPE_CHECKING:
from typing import Any

import pytorch_lightning as pl
from omegaconf import OmegaConf

LOGGER = logging.getLogger(__name__)


class ParallelExecutor(ThreadPoolExecutor):
"""Wraps parallel execution and provides accurate information about errors.
Extends ThreadPoolExecutor to preserve the original traceback and line number.
Reference: https://stackoverflow.com/questions/19309514/getting-original-line-
number-for-exception-in-concurrent-futures/24457608#24457608
"""

def submit(self, fn: Any, *args, **kwargs) -> Callable:
"""Submits the wrapped function instead of `fn`."""
return super().submit(self._function_wrapper, fn, *args, **kwargs)

def _function_wrapper(self, fn: Any, *args: list, **kwargs: dict) -> Callable:
"""Wraps `fn` in order to preserve the traceback of any kind of."""
try:
return fn(*args, **kwargs)
except Exception as exc:
raise sys.exc_info()[0](traceback.format_exc()) from exc


class BasePlotCallback(Callback, ABC):
"""Factory for creating a callback that plots data to Experiment Logging."""

@@ -93,11 +72,21 @@ def __init__(self, config: OmegaConf) -> None:

self.plot = self._plot
self._executor = None
self._error: BaseException = None
self.datashader_plotting = config.diagnostics.plot.datashader

if self.config.diagnostics.plot.asynchronous:
self._executor = ParallelExecutor(max_workers=1)
self._error: BaseException | None = None
LOGGER.info("Setting up asynchronous plotting ...")
self.plot = self._async_plot
self._executor = ThreadPoolExecutor(max_workers=1)
self.loop_thread = threading.Thread(target=self.start_event_loop, daemon=True)
self.loop_thread.start()

def start_event_loop(self) -> None:
"""Start the event loop in a separate thread."""
self.loop = asyncio.new_event_loop()
asyncio.set_event_loop(self.loop)
self.loop.run_forever()

@rank_zero_only
def _output_figure(
@@ -113,27 +102,48 @@ def _output_figure(
save_path = Path(
self.save_basedir,
"plots",
f"{tag}_epoch{epoch:03d}.png",
f"{tag}_epoch{epoch:03d}.jpg",
)

save_path.parent.mkdir(parents=True, exist_ok=True)
fig.savefig(save_path, dpi=100, bbox_inches="tight")
fig.canvas.draw()
image_array = np.array(fig.canvas.renderer.buffer_rgba())
plt.imsave(save_path, image_array, dpi=100)
if self.config.diagnostics.log.wandb.enabled:
import wandb

logger.experiment.log({exp_log_tag: wandb.Image(fig)})

if self.config.diagnostics.log.mlflow.enabled:
run_id = logger.run_id
logger.experiment.log_artifact(run_id, str(save_path))

plt.close(fig) # cleanup

@rank_zero_only
def _plot_with_error_catching(self, trainer: pl.Trainer, args: Any, kwargs: Any) -> None:
"""To execute the plot function but ensuring we catch any errors."""
try:
self._plot(trainer, *args, **kwargs)
except BaseException:
import os

LOGGER.exception(traceback.format_exc())
os._exit(1) # to force exit when sanity val steps are used

def teardown(self, trainer: pl.Trainer, pl_module: pl.LightningModule, stage: str) -> None:
"""Method is called to close the threads."""
"""Teardown the callback."""
del trainer, pl_module, stage # unused
LOGGER.info("Teardown of the Plot Callback ...")

if self._executor is not None:
self._executor.shutdown(wait=True)
LOGGER.info("waiting and shutting down the executor ...")
self._executor.shutdown(wait=False, cancel_futures=True)

self.loop.call_soon_threadsafe(self.loop.stop)
self.loop_thread.join()
# Step 3: Close the asyncio event loop
self.loop_thread._stop()
self.loop_thread._delete()

def apply_output_mask(self, pl_module: pl.LightningModule, data: torch.Tensor) -> torch.Tensor:
if hasattr(pl_module, "output_mask") and pl_module.output_mask is not None:
@@ -147,31 +157,39 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
*args,
**kwargs,
*args: Any,
**kwargs: Any,
) -> None:
"""Plotting function to be implemented by subclasses."""

# Async function to run the plot function in the background thread
async def submit_plot(self, trainer: pl.Trainer, *args: Any, **kwargs: Any) -> None:
"""Async function or coroutine to schedule the plot function."""
loop = asyncio.get_running_loop()
# run_in_executor doesn't support keyword arguments,
await loop.run_in_executor(
self._executor,
self._plot_with_error_catching,
trainer,
args,
kwargs,
) # because loop.run_in_executor expects positional arguments, not keyword arguments

@rank_zero_only
def _async_plot(
self,
trainer: pl.Trainer,
*args: list,
**kwargs: dict,
*args: Any,
**kwargs: Any,
) -> None:
"""To execute the plot function but ensuring we catch any errors."""
future = self._executor.submit(
self._plot,
trainer,
*args,
**kwargs,
)
# otherwise the error won't be thrown till the validation epoch is finished
try:
future.result()
except Exception:
LOGGER.exception("Critical error occurred in asynchronous plots.")
sys.exit(1)
"""Run the plot function asynchronously.
This is the function that is called by the callback. It schedules the plot
function to run in the background thread. Since we have an event loop running in
the background thread, we need to schedule the plot function to run in that
loop.
"""
asyncio.run_coroutine_threadsafe(self.submit_plot(trainer, *args, **kwargs), self.loop)


class BasePerBatchPlotCallback(BasePlotCallback):
@@ -192,26 +210,12 @@ def __init__(self, config: OmegaConf, every_n_batches: int | None = None):
super().__init__(config)
self.every_n_batches = every_n_batches or self.config.diagnostics.plot.frequency.batch

@abstractmethod
@rank_zero_only
def _plot(
def on_validation_batch_end(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
**kwargs,
) -> None:
"""Plotting function to be implemented by subclasses."""

@rank_zero_only
def on_validation_batch_end(
self,
trainer,
pl_module,
output,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
**kwargs,
@@ -310,12 +314,12 @@ def __init__(
@rank_zero_only
def _plot(
self,
trainer,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx,
epoch,
batch_idx: int,
epoch: int,
) -> None:
_ = output

@@ -406,9 +410,9 @@ def _plot(
@rank_zero_only
def on_validation_batch_end(
self,
trainer,
pl_module,
output,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
output: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
) -> None:
@@ -454,7 +458,7 @@ def _plot(
_ = epoch
model = pl_module.model.module.model if hasattr(pl_module.model, "module") else pl_module.model.model

fig = plot_graph_node_features(model)
fig = plot_graph_node_features(model, datashader=self.datashader_plotting)

self._output_figure(
trainer.logger,
@@ -750,6 +754,7 @@ def _plot(
data[0, ...].squeeze(),
data[rollout_step + 1, ...].squeeze(),
output_tensor[rollout_step, ...],
datashader=self.datashader_plotting,
precip_and_related_fields=self.precip_and_related_fields,
)

@@ -839,7 +844,7 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
@@ -921,7 +926,7 @@ def _plot(
self,
trainer: pl.Trainer,
pl_module: pl.LightningModule,
outputs: list,
outputs: list[torch.Tensor],
batch: torch.Tensor,
batch_idx: int,
epoch: int,
2 changes: 1 addition & 1 deletion src/anemoi/training/diagnostics/maps.py
@@ -32,7 +32,7 @@ def __init__(self) -> None:
def __call__(self, lon: np.ndarray, lat: np.ndarray) -> tuple[np.ndarray, np.ndarray]:
lon_rad = np.radians(lon)
lat_rad = np.radians(lat)
x = [v - 2 * np.pi if v > np.pi else v for v in lon_rad]
x = np.array([v - 2 * np.pi if v > np.pi else v for v in lon_rad], dtype=lon_rad.dtype)
y = lat_rad
return x, y
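
The comprehension above wraps longitudes from the [0, 2π) range into (−π, π]
before plotting. For illustration only (not part of this commit), an equivalent
fully vectorised form would be:

    import numpy as np

    lon = np.array([0.0, 90.0, 180.0, 270.0])  # degrees
    lon_rad = np.radians(lon)
    x = np.where(lon_rad > np.pi, lon_rad - 2.0 * np.pi, lon_rad)  # 270 deg becomes -pi/2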
