Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 639903056
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Jun 3, 2024
1 parent 3bf566e commit 8d1615a
Showing 1 changed file with 19 additions and 1 deletion.
20 changes: 19 additions & 1 deletion swirl_dynamics/projects/debiasing/rectified_flow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,13 +14,13 @@

r"""The main entry point for running training loops."""

import itertools
import json
from os import path as osp

from absl import app
from absl import flags
from absl import logging

import jax
from ml_collections import config_flags
import optax
Expand Down Expand Up @@ -198,6 +198,24 @@ def main(argv):
drop_remainder=True,
worker_count=config.num_workers,)

elif "dummy_loaders" in config and config.dummy_loaders:
# Dummy data.
fake_batch_lens2 = {
"x_0": jax.numpy.zeros(
(config.batch_size,) + config.input_shapes[0][1:]
)
}
fake_batch_era5 = {
"x_1": jax.numpy.ones(
(config.batch_size,) + config.input_shapes[1][1:]
)
}

era5_loader_train = era5_loader_eval = itertools.repeat(fake_batch_era5)
lens2_loader_train = lens2_loader_eval = itertools.repeat(
fake_batch_lens2
)

else:
era5_loader_train = data_utils.create_era5_loader(
date_range=config.data_range_train,
Expand Down

0 comments on commit 8d1615a

Please sign in to comment.