Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 619569268
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Mar 27, 2024
1 parent c282a0c commit 5ee133c
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 27 deletions.
2 changes: 2 additions & 0 deletions swirl_dynamics/templates/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@
# pylint: disable=g-importing-member
# pylint: disable=g-multiple-import

from swirl_dynamics.templates import utils
from swirl_dynamics.templates.callbacks import (
Callback,
LogGinConfig,
MatplotlibFigureAsImage,
ParameterOverview,
ProgressReport,
TqdmProgressBar,
Expand Down
132 changes: 106 additions & 26 deletions swirl_dynamics/templates/callbacks.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,13 @@
import flax
import gin
import jax
import matplotlib.backends.backend_agg as mpl_agg
import matplotlib.pyplot as plt
import numpy as np
import orbax.checkpoint as ocp
from swirl_dynamics.templates import train_states
from swirl_dynamics.templates import trainers
from swirl_dynamics.templates import utils
import tqdm.auto as tqdm

Array = jax.Array
Expand Down Expand Up @@ -153,7 +156,7 @@ def on_eval_batches_end(
self.last_eval_metric = eval_metrics

def on_train_end(self, trainer: Trainer) -> None:
# always save at the end of training
# Always save a checkpoint at the end of training.
if self.ckpt_manager.latest_step() != trainer.train_state.int_step:
self.ckpt_manager.save(
trainer.train_state.int_step,
Expand All @@ -162,6 +165,7 @@ def on_train_end(self, trainer: Trainer) -> None:
)


@utils.primary_process_only
class ProgressReport(Callback):
"""Callback that reports progress during training.
Expand Down Expand Up @@ -252,20 +256,20 @@ def on_train_end(self, trainer: Trainer) -> None:
self.bar.close()


@utils.primary_process_only
class LogGinConfig(Callback):
  """Writes the gin config string as text to TensorBoard.

  Decorated with `utils.primary_process_only`, so in a multi-host setup the
  config is written exactly once (by process 0) rather than once per host.
  """

  def on_train_begin(self, trainer: trainers.BaseTrainer) -> None:
    # The *operative* config contains only the bindings actually consumed by
    # the program, which is what we want to record for reproducibility.
    config_str = gin.operative_config_str()
    self.metric_writer.write_texts(
        # Written at step 0: the config is static for the whole run.
        0, {"config": gin.markdown(config_str), "raw_config_str": config_str}
    )
    self.metric_writer.flush()


def _get_markdown_param_table(
params: dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]]
params: dict[str, np.ndarray] | Mapping[str, Mapping[str, Any]],
) -> str:
"""Returns a markdown table of parameters."""
param_table = parameter_overview.get_parameter_overview(
Expand All @@ -281,30 +285,106 @@ def _get_markdown_param_table(
return "\n".join([header, hline] + body + ["", total])


@utils.primary_process_only
@dataclasses.dataclass
class ParameterOverview(Callback):
  """Writes parameter overview to INFO log and/or TensorBoard.

  Decorated with `utils.primary_process_only`, so the overview is emitted only
  by process 0 in a multi-host setup.

  Attributes:
    log_to_info: Whether to print parameter overview to log (INFO level).
    log_to_tb: Whether to add parameter overview to tensorboard (under text
      tab).
  """

  log_to_info: bool = True
  log_to_tb: bool = True

  def on_train_begin(self, trainer: trainers.BaseTrainer) -> None:
    # Parameters can only be extracted from train states that expose a
    # `.params` field (i.e. `BasicTrainState`); other state types are skipped
    # with a warning rather than crashing training.
    if isinstance(trainer.train_state, train_states.BasicTrainState):
      params = trainer.train_state.params
      if self.log_to_info:
        logging.info("Logging parameter overview.")
        parameter_overview.log_parameter_overview(params)

      if self.log_to_tb:
        self.metric_writer.write_texts(
            # Step 0: the parameter structure does not change during training.
            0,
            {"parameters": _get_markdown_param_table(params)},
        )
        self.metric_writer.flush()
    else:
      logging.warning(
          "ParameterOverview callback: unable to extract parameters for"
          " overview."
      )


Figures = Sequence[plt.Figure] | plt.Figure


def figure_to_image(figures: Figures) -> np.ndarray:
  """Converts a sequence of figures to image data ingestable by tensorboard.

  Args:
    figures: A single matplotlib figure or a sequence of them. Each figure is
      closed after rendering to free its resources.

  Returns:
    A uint8 array of shape (num_figures, height, width, 3) holding the RGB
    rasterization of every figure.
  """

  def _rasterize(fig: plt.Figure) -> np.ndarray:
    # Attach an Agg (offscreen) canvas so rendering works without a display.
    agg_canvas = mpl_agg.FigureCanvasAgg(fig)
    agg_canvas.draw()
    rgba_flat = np.frombuffer(agg_canvas.buffer_rgba(), dtype=np.uint8)
    width, height = fig.canvas.get_width_height()
    # Drop the alpha channel; tensorboard expects HWC RGB data.
    rgb = rgba_flat.reshape([height, width, 4])[..., :3]
    plt.close(fig)
    return rgb

  if isinstance(figures, plt.Figure):
    figures = [figures]
  rendered = []
  for fig in figures:
    rendered.append(_rasterize(fig))
  return np.stack(rendered)


@utils.primary_process_only
class MatplotlibFigureAsImage(Callback):
  """Makes matplotlib figures and writes to tensorboard.

  Child classes should create model-specific plots in standard callback hooks
  and call `.write_images()` to write them to TB. Plot data should be returned
  from eval step of the model and declared as `CollectingMetric` in the
  trainer.

  Pattern::

    # Model
    class Model:

      def eval_step(...):
        sample = ...  # Compute data required for plotting.
        return {
            "generated_sample": sample,
        }

    # Trainer
    class Trainer:

      @flax.struct.dataclass
      class EvalMetrics(clu.metrics.Collection):
        eval_plot_data: clu.metrics.CollectingMetric.from_outputs(
            ("generated_sample",)  # Matches dict key in model.eval_step output.
        )

    # Custom plot callback
    class PlotSamples(MplFigureAsImage):

      def on_eval_batches_end(self, trainer, eval_metrics):
        plot_data = eval_metrics["eval_plot_data"]  # Same name in trainer cls.
        sample = plot_data["generated_sample"]  # Key in model.eval_step output.
        # make plots
        fig, ax = plt.subplots()
        ax.imshow(sample[0])  # Plotting first element of aggregated samples.
        # write plots to TB
        self.write_images(
            trainer.train_state.int_step, {"generated_sample": fig}
        )
  """

  def write_images(self, step: int, images: Mapping[str, Figures]) -> None:
    """Rasterizes the given figures and writes them to tensorboard at `step`."""
    rendered = {}
    for name, figs in images.items():
      rendered[name] = figure_to_image(figs)
    self.metric_writer.write_images(step, rendered)
24 changes: 23 additions & 1 deletion swirl_dynamics/templates/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,11 +15,13 @@
"""Utility functions for the template."""

import collections
from collections.abc import Mapping, Sequence
from collections.abc import Callable, Mapping, Sequence
import functools
import os
from typing import Any

from clu import values
import jax
import jax.numpy as jnp
import numpy as np
import optax
Expand All @@ -28,6 +30,26 @@
Scalar = Any


def primary_process_only(cls: type[Any]) -> type[Any]:
  """Class decorator that modifies all methods to run on primary host only.

  Every public method defined directly on `cls` is wrapped so that its body
  executes only when `jax.process_index() == 0`; on all other processes the
  wrapped method is a no-op returning None. Useful for callbacks that write
  logs, metrics or figures which should be emitted exactly once per job.

  Note: staticmethods, classmethods and properties are intentionally left
  unwrapped — the wrapper takes `self` as first argument and would break them.

  Args:
    cls: The class to modify. It is mutated in place.

  Returns:
    The same class, with its public methods wrapped.
  """

  def wrap_method(method: Callable[..., Any]) -> Callable[..., Any]:
    @functools.wraps(method)
    def wrapper(self, *args, **kwargs):
      if jax.process_index() == 0:
        return method(self, *args, **kwargs)
      # Secondary processes skip the body entirely.
      return None

    return wrapper

  for attr_name, attr_value in cls.__dict__.items():
    # `callable(attr_value)` alone would also match nested classes (classes
    # are callable) and, on Python 3.10+, staticmethod objects; wrapping those
    # with a `self`-taking wrapper would break them, so exclude them
    # explicitly.
    if (
        callable(attr_value)
        and not isinstance(attr_value, (type, staticmethod, classmethod))
        and not attr_name.startswith("__")
    ):
      setattr(cls, attr_name, wrap_method(attr_value))

  return cls


def load_scalars_from_tfevents(
logdir: str,
) -> Mapping[int, Mapping[str, Scalar]]:
Expand Down

0 comments on commit 5ee133c

Please sign in to comment.