diff --git a/swirl_dynamics/templates/__init__.py b/swirl_dynamics/templates/__init__.py
index 6b07110..cb2ca88 100644
--- a/swirl_dynamics/templates/__init__.py
+++ b/swirl_dynamics/templates/__init__.py
@@ -17,9 +17,11 @@
 # pylint: disable=g-importing-member
 # pylint: disable=g-multiple-import
 
+from swirl_dynamics.templates import utils
 from swirl_dynamics.templates.callbacks import (
     Callback,
     LogGinConfig,
+    MatplotlibFigureAsImage,
     ParameterOverview,
     ProgressReport,
     TqdmProgressBar,
diff --git a/swirl_dynamics/templates/callbacks.py b/swirl_dynamics/templates/callbacks.py
index 1f52a15..aa52e8a 100644
--- a/swirl_dynamics/templates/callbacks.py
+++ b/swirl_dynamics/templates/callbacks.py
@@ -27,10 +27,13 @@
 import flax
 import gin
 import jax
+import matplotlib.backends.backend_agg as mpl_agg
+import matplotlib.pyplot as plt
 import numpy as np
 import orbax.checkpoint as ocp
 from swirl_dynamics.templates import train_states
 from swirl_dynamics.templates import trainers
+from swirl_dynamics.templates import utils
 import tqdm.auto as tqdm
 
 Array = jax.Array
@@ -153,7 +156,7 @@ def on_eval_batches_end(
     self.last_eval_metric = eval_metrics
 
   def on_train_end(self, trainer: Trainer) -> None:
-    # always save at the end of training
+    # Always save a checkpoint at the end of training.
     if self.ckpt_manager.latest_step() != trainer.train_state.int_step:
       self.ckpt_manager.save(
           trainer.train_state.int_step,
@@ -162,6 +165,7 @@ def on_train_end(self, trainer: Trainer) -> None:
       )
 
 
+@utils.primary_process_only
 class ProgressReport(Callback):
   """Callback that reports progress during training.
 
@@ -252,20 +256,20 @@ def on_train_end(self, trainer: Trainer) -> None:
     self.bar.close()
 
 
+@utils.primary_process_only
 class LogGinConfig(Callback):
   """Write gin config string as text in TensorBoard."""
 
   def on_train_begin(self, trainer: trainers.BaseTrainer) -> None:
-    if jax.process_index() == 0:
-      config_str = gin.operative_config_str()
-      self.metric_writer.write_texts(
-          0, {"config": gin.markdown(config_str), "raw_config_str": config_str}
-      )
-      self.metric_writer.flush()
+    config_str = gin.operative_config_str()
+    self.metric_writer.write_texts(
+        0, {"config": gin.markdown(config_str), "raw_config_str": config_str}
+    )
+    self.metric_writer.flush()
 
 
 def _get_markdown_param_table(
-    params: dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]]
+    params: dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]],
 ) -> str:
   """Returns a markdown table of parameters."""
   param_table = parameter_overview.get_parameter_overview(
@@ -281,30 +285,106 @@
   return "\n".join([header, hline] + body + ["", total])
 
 
+@utils.primary_process_only
 @dataclasses.dataclass
 class ParameterOverview(Callback):
-  """Writes parameter overview to INFO log and/or TensorBoard."""
+  """Writes parameter overview to INFO log and/or TensorBoard.
+
+  Attributes:
+    log_to_info: Whether to print parameter overview to log (INFO level).
+    log_to_tb: Whether to add parameter overview to tensorboard (under text
+      tab).
+  """
 
   log_to_info: bool = True
   log_to_tb: bool = True
 
   def on_train_begin(self, trainer: trainers.BaseTrainer) -> None:
-    if jax.process_index() == 0:
-      if isinstance(trainer.train_state, train_states.BasicTrainState):
-        params = trainer.train_state.params
-        if self.log_to_info:
-          logging.info("Logging parameter overview.")
-          parameter_overview.log_parameter_overview(params)
-
-        if self.log_to_tb:
-          self.metric_writer.write_texts(
-              0,
-              {"parameters": _get_markdown_param_table(params)},
-          )
-          self.metric_writer.flush()
-      else:
-        logging.warning(
-            "ParameterOverview callback: unable extract parameters for"
-            " overivew."
+    if isinstance(trainer.train_state, train_states.BasicTrainState):
+      params = trainer.train_state.params
+      if self.log_to_info:
+        logging.info("Logging parameter overview.")
+        parameter_overview.log_parameter_overview(params)
+
+      if self.log_to_tb:
+        self.metric_writer.write_texts(
+            0,
+            {"parameters": _get_markdown_param_table(params)},
         )
+        self.metric_writer.flush()
+    else:
+      logging.warning(
+          "ParameterOverview callback: unable to extract parameters for overview."
+      )
+
+
+Figures = Sequence[plt.Figure] | plt.Figure
+
+
+def figure_to_image(figures: Figures) -> np.ndarray:
+  """Converts a sequence of figures to image data ingestible by tensorboard."""
+
+  def render_to_rgb(figure: plt.Figure) -> np.ndarray:
+    canvas = mpl_agg.FigureCanvasAgg(figure)
+    canvas.draw()
+    data = np.frombuffer(canvas.buffer_rgba(), dtype=np.uint8)
+    w, h = figure.canvas.get_width_height()
+    image_hwc = data.reshape([h, w, 4])[:, :, :3]
+    plt.close(figure)
+    return image_hwc
+
+  if isinstance(figures, plt.Figure):
+    figures = [figures]
+  images = [render_to_rgb(figure) for figure in figures]
+  return np.stack(images)
+
+
+@utils.primary_process_only
+class MatplotlibFigureAsImage(Callback):
+  """Makes matplotlib figures and writes them to tensorboard.
+
+  Child classes should create model-specific plots in standard callback hooks
+  and call `.write_images()` to write them to TB. Plot data should be returned
+  from the model's eval step and declared as a `CollectingMetric` in the
+  trainer. Pattern::
+
+    # Model
+    class Model:
+
+      def eval_step(...):
+        sample = ...  # Compute data required for plotting.
+        return {
+            "generated_sample": sample,
+        }
+
+    # Trainer
+    class Trainer:
+
+      @flax.struct.dataclass
+      class EvalMetrics(clu.metrics.Collection):
+        eval_plot_data: clu.metrics.CollectingMetric.from_outputs(
+            ("generated_sample",)  # Matches dict key in model.eval_step output.
+        )
+
+    # Custom plot callback
+    class PlotSamples(MatplotlibFigureAsImage):
+
+      def on_eval_batches_end(self, trainer, eval_metrics):
+        plot_data = eval_metrics["eval_plot_data"]  # Same name in trainer cls.
+        sample = plot_data["generated_sample"]  # Key in model.eval_step output.
+
+        # make plots
+        fig, ax = plt.subplots()
+        ax.imshow(sample[0])  # Plotting first element of aggregated samples.
+
+        # write plots to TB
+        self.write_images(
+            trainer.train_state.int_step, {"generated_sample": fig}
+        )
+  """
+
+  def write_images(self, step: int, images: Mapping[str, Figures]) -> None:
+    self.metric_writer.write_images(
+        step, {k: figure_to_image(v) for k, v in images.items()}
+    )
 
diff --git a/swirl_dynamics/templates/utils.py b/swirl_dynamics/templates/utils.py
index 5cd80ae..8e3c279 100644
--- a/swirl_dynamics/templates/utils.py
+++ b/swirl_dynamics/templates/utils.py
@@ -15,11 +15,13 @@
 """Utility functions for the template."""
 
 import collections
-from collections.abc import Mapping, Sequence
+from collections.abc import Callable, Mapping, Sequence
+import functools
 import os
 from typing import Any
 
 from clu import values
+import jax
 import jax.numpy as jnp
 import numpy as np
 import optax
@@ -28,6 +30,26 @@
 Scalar = Any
 
 
+def primary_process_only(cls: type[Any]) -> type[Any]:
+  """Class decorator that modifies all methods to run on primary host only."""
+
+  def wrap_method(method: Callable[..., Any]) -> Callable[..., Any]:
+    @functools.wraps(method)
+    def wrapper(self, *args, **kwargs):
+      if jax.process_index() == 0:
+        return method(self, *args, **kwargs)
+      else:
+        return None
+
+    return wrapper
+
+  for attr_name, attr_value in cls.__dict__.items():
+    if callable(attr_value) and not attr_name.startswith("__"):
+      setattr(cls, attr_name, wrap_method(attr_value))
+
+  return cls
+
+
 def load_scalars_from_tfevents(
     logdir: str,
 ) -> Mapping[int, Mapping[str, Scalar]]:
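
Reviewer note (not part of the diff): a minimal sketch of how the new `utils.primary_process_only` decorator behaves when applied to a user-defined class. `DemoCallback` and its method are illustrative names, not part of this change.

    import jax
    from swirl_dynamics.templates import utils


    @utils.primary_process_only
    class DemoCallback:
      """Illustrative only: methods run solely on the primary host."""

      def on_train_begin(self, trainer):
        # The decorator replaces this method with a wrapper that calls it only
        # where jax.process_index() == 0; every other process gets None back.
        return f"written by process {jax.process_index()}"


    cb = DemoCallback()
    out = cb.on_train_begin(trainer=None)  # A string on process 0, None elsewhere.

Because the wrapping walks the decorated class's own `__dict__`, methods that a subclass adds or overrides (e.g. `PlotSamples.on_eval_batches_end` in the docstring example) are not wrapped unless the subclass is decorated as well; only inherited methods of the decorated class, such as `write_images`, carry the primary-process guard.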