Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629222401
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Apr 29, 2024
1 parent f4bec78 commit 30f62b7
Showing 1 changed file with 22 additions and 1 deletion.
23 changes: 22 additions & 1 deletion swirl_dynamics/projects/debiasing/rectified_flow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@

from absl import app
from absl import flags
from absl import logging

import jax
from ml_collections import config_flags
import optax
Expand All @@ -30,6 +32,7 @@
from swirl_dynamics.templates import train
import tensorflow as tf


_ERA5_VARIABLES = {
"temperature": {"level": 1000},
"specific_humidity": {"level": 1000},
Expand All @@ -48,6 +51,7 @@
FLAGS = flags.FLAGS

flags.DEFINE_string("workdir", None, "Directory to store model data.")

config_flags.DEFINE_config_file(
"config",
None,
Expand All @@ -57,6 +61,16 @@


def main(argv):
# Flags --jax_backend_target and --jax_xla_backend are available through JAX.
if FLAGS.jax_backend_target:
logging.info("Using JAX backend target %s", FLAGS.jax_backend_target)
jax_xla_backend = (
"None" if FLAGS.jax_xla_backend is None else FLAGS.jax_xla_backend
)
logging.info("Using JAX XLA backend %s", jax_xla_backend)
logging.info("JAX process: %d / %d", jax.process_index(), jax.process_count())
logging.info("JAX devices: %r", jax.devices())

if len(argv) > 1:
raise app.UsageError("Too many command-line arguments.")
config = FLAGS.config
Expand Down Expand Up @@ -279,10 +293,17 @@ def main(argv):
base_dir=workdir,
options=ckpt_options,
),
# Callback to add the number of iterations/second.
callbacks.ProgressReport(
num_train_steps=config.num_train_steps,
),
# TODO add a plot callback.
),
)


if __name__ == "__main__":
app.run(main)
# Provide access to --jax_backend_target and --jax_xla_backend flags.
jax.config.config_with_absl()
handler = app.run
handler(main)

0 comments on commit 30f62b7

Please sign in to comment.