From 8d1615a0b1277d16098458bf8298d84284efd734 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Leonardo=20Zepeda-N=C3=BA=C3=B1ez?= Date: Mon, 3 Jun 2024 13:58:48 -0700 Subject: [PATCH] Code update PiperOrigin-RevId: 639903056 --- .../projects/debiasing/rectified_flow/main.py | 20 ++++++++++++++++++- 1 file changed, 19 insertions(+), 1 deletion(-) diff --git a/swirl_dynamics/projects/debiasing/rectified_flow/main.py b/swirl_dynamics/projects/debiasing/rectified_flow/main.py index 5251239..cdd189b 100644 --- a/swirl_dynamics/projects/debiasing/rectified_flow/main.py +++ b/swirl_dynamics/projects/debiasing/rectified_flow/main.py @@ -14,13 +14,13 @@ r"""The main entry point for running training loops.""" +import itertools import json from os import path as osp from absl import app from absl import flags from absl import logging - import jax from ml_collections import config_flags import optax @@ -198,6 +198,24 @@ def main(argv): drop_remainder=True, worker_count=config.num_workers,) + elif "dummy_loaders" in config and config.dummy_loaders: + # Dummy data. + fake_batch_lens2 = { + "x_0": jax.numpy.zeros( + (config.batch_size,) + config.input_shapes[0][1:] + ) + } + fake_batch_era5 = { + "x_1": jax.numpy.ones( + (config.batch_size,) + config.input_shapes[1][1:] + ) + } + + era5_loader_train = era5_loader_eval = itertools.repeat(fake_batch_era5) + lens2_loader_train = lens2_loader_eval = itertools.repeat( + fake_batch_lens2 + ) + else: era5_loader_train = data_utils.create_era5_loader( date_range=config.data_range_train,