diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/main.py b/swirl_dynamics/projects/debiasing/rectified_flow/main.py index 51387f2..2976a70 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/main.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/main.py @@ -19,6 +19,8 @@ from absl import app from absl import flags +from absl import logging + import jax from ml_collections import config_flags import optax @@ -30,6 +32,7 @@ from swirl_dynamics.templates import train import tensorflow as tf + _ERA5_VARIABLES = { "temperature": {"level": 1000}, "specific_humidity": {"level": 1000}, @@ -48,6 +51,7 @@ FLAGS = flags.FLAGS flags.DEFINE_string("workdir", None, "Directory to store model data.") + config_flags.DEFINE_config_file( "config", None, @@ -57,6 +61,16 @@ def main(argv): + # Flags --jax_backend_target and --jax_xla_backend are available through JAX. + if FLAGS.jax_backend_target: + logging.info("Using JAX backend target %s", FLAGS.jax_backend_target) + jax_xla_backend = ( + "None" if FLAGS.jax_xla_backend is None else FLAGS.jax_xla_backend + ) + logging.info("Using JAX XLA backend %s", jax_xla_backend) + logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count()) + logging.info("JAX devices: %r", jax.devices()) + if len(argv) > 1: raise app.UsageError("Too many command-line arguments.") config = FLAGS.config @@ -279,10 +293,17 @@ def main(argv): base_dir=workdir, options=ckpt_options, ), + # Callback to add the number of iterations/second. + callbacks.ProgressReport( + num_train_steps=config.num_train_steps, + ), # TODO add a plot callback. ), ) if __name__ == "__main__": - app.run(main) + # Provide access to --jax_backend_target and --jax_xla_backend flags. + jax.config.config_with_absl() + handler = app.run + handler(main)