Commit 6244dc1

Code update

PiperOrigin-RevId: 640007576

Forgotten authored and The swirl_dynamics Authors committed Jun 4, 2024
1 parent 8d1615a commit 6244dc1

Showing 2 changed files with 166 additions and 50 deletions.

164 changes: 141 additions & 23 deletions swirl_dynamics/projects/debiasing/rectified_flow/main.py
@@ -34,15 +34,15 @@


_ERA5_VARIABLES = {
- "temperature": {"level": 1000},
+ "2m_temperature": None,
"specific_humidity": {"level": 1000},
"geopotential": {"level": [200, 500]},
"mean_sea_level_pressure": None,
}

_ERA5_WIND_COMPONENTS = {
- "u_component_of_wind": {"level": 1000},
- "v_component_of_wind": {"level": 1000},
+ "10m_u_component_of_wind": None,
+ "10m_v_component_of_wind": None,
}

_LENS2_MEMBER_INDEXER = {"member": "cmip6_1001_001"}
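The variable switch above replaces pressure-level ERA5 fields with their near-surface counterparts, whose indexer is None because no level selection is needed. As a hedged illustration (not code from this commit) of how a {variable: indexer-or-None} mapping like _ERA5_VARIABLES is typically applied to an xarray dataset (select_variables is a hypothetical helper):

import xarray as xr

def select_variables(ds: xr.Dataset, variables: dict) -> xr.Dataset:
  """Hypothetical helper: applies a {name: indexer-or-None} mapping."""
  out = {}
  for name, indexer in variables.items():
    # A None indexer keeps the (surface) field as-is; a dict such as
    # {"level": [200, 500]} selects those pressure levels.
    out[name] = ds[name].sel(**indexer) if indexer else ds[name]
  return xr.Dataset(out)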
@@ -81,8 +81,9 @@ def main(argv):
# Only 0-th process should write the json file to disk, in order to avoid
# race conditions.
if jax.process_index() == 0:
- with tf.io.gfile.GFile(name=osp.join(workdir,
-                                      "config.json"), mode="w") as f:
+ with tf.io.gfile.GFile(
+     name=osp.join(workdir, "config.json"), mode="w"
+ ) as f:
conf_json = config.to_json_best_effort()
if isinstance(conf_json, str):  # Sometimes `.to_json()` returns a string.
conf_json = json.loads(conf_json)
@@ -99,20 +100,20 @@ def main(argv):
)

optimizer = optax.chain(
# optax.clip(config.clip),
optax.clip_by_global_norm(config.max_norm),
optax.adam(
learning_rate=schedule,
b1=config.beta1,
),
# optax.ema(decay=0.999)
)

assert (
config.input_shapes[0][-1] == config.input_shapes[1][-1]
and config.input_shapes[0][-1] == config.out_channels
)

# TODO: add a utility function to encapsulate all this code.

if config.tf_grain_hdf5:

train_dataloader = data_utils.UnpairedDataLoader(
@@ -172,31 +173,35 @@ def main(argv):
seed=config.seed,
batch_size=config.batch_size,
drop_remainder=True,
- worker_count=config.num_workers,)
+ worker_count=config.num_workers,
+ )

lens2_loader_train = data_utils.create_default_lens2_loader(
date_range=config.data_range_train,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size,
drop_remainder=True,
- worker_count=config.num_workers,)
+ worker_count=config.num_workers,
+ )

era5_loader_eval = data_utils.create_default_era5_loader(
date_range=config.data_range_eval,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size_eval,
drop_remainder=True,
- worker_count=config.num_workers,)
+ worker_count=config.num_workers,
+ )

lens2_loader_eval = data_utils.create_default_lens2_loader(
date_range=config.data_range_eval,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size_eval,
drop_remainder=True,
- worker_count=config.num_workers,)
+ worker_count=config.num_workers,
+ )

elif "dummy_loaders" in config and config.dummy_loaders:
# Dummy data.
@@ -216,6 +221,67 @@ def main(argv):
fake_batch_lens2
)

elif "chunked_loaders" in config and config.chunked_loaders:
logging.info("Using chunked loaders.")
era5_loader_train = data_utils.create_chunked_era5_loader(
date_range=config.data_range_train,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size,
num_chunks=config.num_chunks,
drop_remainder=True,
worker_count=config.num_workers,
)

lens2_loader_train = data_utils.create_chunked_lens2_loader(
date_range=config.data_range_train,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size,
num_chunks=config.num_chunks,
drop_remainder=True,
worker_count=config.num_workers,
)

era5_loader_eval = data_utils.create_chunked_era5_loader(
date_range=config.data_range_eval,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size_eval,
num_chunks=config.num_chunks,
drop_remainder=True,
worker_count=config.num_workers,
)

lens2_loader_eval = data_utils.create_chunked_lens2_loader(
date_range=config.data_range_eval,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size_eval,
num_chunks=config.num_chunks,
drop_remainder=True,
worker_count=config.num_workers,
)

elif "chunked_aligned_loader" in config and config.chunked_aligned_loader:
logging.info("Using chunked aligned loaders.")

lens2_era5_loader_train = data_utils.create_lens2_era5_loader_chunked(
date_range=config.data_range_train,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size_eval,
drop_remainder=True,
worker_count=config.num_workers,
)
lens2_era5_loader_eval = data_utils.create_lens2_era5_loader_chunked(
date_range=config.data_range_eval,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size_eval,
drop_remainder=True,
worker_count=config.num_workers,
)
else:
era5_loader_train = data_utils.create_era5_loader(
date_range=config.data_range_train,
@@ -225,7 +291,8 @@ def main(argv):
seed=config.seed,
batch_size=config.batch_size,
drop_remainder=True,
- worker_count=config.num_workers,)
+ worker_count=config.num_workers,
+ )

lens2_loader_train = data_utils.create_lens2_loader(
date_range=config.data_range_train,
@@ -235,7 +302,8 @@ def main(argv):
variable_names=lens2_variable_names,
batch_size=config.batch_size,
drop_remainder=True,
- worker_count=config.num_workers,)
+ worker_count=config.num_workers,
+ )

era5_loader_eval = data_utils.create_era5_loader(
date_range=config.data_range_eval,
@@ -245,7 +313,8 @@ def main(argv):
variables=era5_variables,
wind_components=era5_wind_components,
drop_remainder=True,
- worker_count=config.num_workers,)
+ worker_count=config.num_workers,
+ )

lens2_loader_eval = data_utils.create_lens2_loader(
date_range=config.data_range_eval,
@@ -255,16 +324,65 @@ def main(argv):
member_indexer=lens2_member_indexer,
variable_names=lens2_variable_names,
drop_remainder=True,
- worker_count=config.num_workers,)
+ worker_count=config.num_workers,
+ )

- # Then create the mixed dataloaders here.
- train_dataloader = data_utils.DualLens2Era5Dataset(
-     era5_loader=era5_loader_train, lens2_loader=lens2_loader_train
- )
# Creating the DataLoaders.
if "chunked_loaders" in config and config.chunked_loaders:

# Then create the mixed dataloaders here.
train_dataloader = data_utils.DualChunkedLens2Era5Dataset(
era5_loader=era5_loader_train, lens2_loader=lens2_loader_train # pylint: disable=undefined-variable
)

eval_dataloader = data_utils.DualChunkedLens2Era5Dataset(
era5_loader=era5_loader_eval, lens2_loader=lens2_loader_eval # pylint: disable=undefined-variable
)

elif "date_aligned" in config and config.date_aligned:
logging.info("Using date aligned loaders.")
train_dataloader = data_utils.create_lens2_era5_loader(
date_range=config.data_range_train,
output_variables=era5_variables,
output_wind_components=era5_wind_components,
input_member_indexer=lens2_member_indexer,
input_variable_names=lens2_variable_names,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size,
drop_remainder=True,
worker_count=config.num_workers,
)
eval_dataloader = data_utils.create_lens2_era5_loader(
date_range=config.data_range_eval,
output_variables=era5_variables,
output_wind_components=era5_wind_components,
input_member_indexer=lens2_member_indexer,
input_variable_names=lens2_variable_names,
shuffle=config.shuffle,
seed=config.seed,
batch_size=config.batch_size_eval,
drop_remainder=True,
worker_count=config.num_workers,
)
elif "chunked_aligned_loader" in config and config.chunked_aligned_loader:
logging.info("Using chunked aligned loaders.")
train_dataloader = data_utils.AlignedChunkedLens2Era5Dataset(
loader=lens2_era5_loader_train # pylint: disable=undefined-variable
)
eval_dataloader = data_utils.AlignedChunkedLens2Era5Dataset(
loader=lens2_era5_loader_eval # pylint: disable=undefined-variable
)
else:
# Then create the mixed dataloaders here.
train_dataloader = data_utils.DualLens2Era5Dataset(
era5_loader=era5_loader_train, lens2_loader=lens2_loader_train # pylint: disable=undefined-variable
)

eval_dataloader = data_utils.DualLens2Era5Dataset(
era5_loader=era5_loader_eval, lens2_loader=lens2_loader_eval # pylint: disable=undefined-variable
)

- eval_dataloader = data_utils.DualLens2Era5Dataset(
-     era5_loader=era5_loader_eval, lens2_loader=lens2_loader_eval
- )
else:  # To keep the linter from complaining.
train_dataloader = None
eval_dataloader = None
@@ -294,7 +412,7 @@ def main(argv):
), # This must agree with the expected sample shape.
flow_model=flow_model,
min_eval_time_lvl=config.min_time, # This should be close to 0.
- max_eval_time_lvl=config.max_time # It should be close to 1.
+ max_eval_time_lvl=config.max_time, # It should be close to 1.
)

# Defining the trainer.
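Taken together, the main.py changes turn the data-pipeline setup into a dispatch over optional config flags (dummy_loaders, chunked_loaders, chunked_aligned_loader, date_aligned). A minimal sketch of that guard pattern, assuming ml_collections config objects; the returned labels stand in for the real loader construction. The membership test matters: a bare string literal in a condition is always truthy, so the flag must be looked up in the config explicitly.

from ml_collections import config_dict

def pick_loader_branch(config) -> str:
  # The `"flag" in config` check tolerates configs written before a flag
  # existed; the flag must also be set to True to take the branch.
  if "chunked_loaders" in config and config.chunked_loaders:
    return "chunked"
  elif "chunked_aligned_loader" in config and config.chunked_aligned_loader:
    return "chunked_aligned"
  elif "date_aligned" in config and config.date_aligned:
    return "date_aligned"
  return "default"

config = config_dict.ConfigDict({"chunked_loaders": True})
assert pick_loader_branch(config) == "chunked"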
52 changes: 25 additions & 27 deletions swirl_dynamics/projects/debiasing/rectified_flow/models.py
@@ -62,9 +62,7 @@ class FlowFlaxModule(Protocol):
NOTE: This protocol is for reference only and not statically checked.
"""

- def __call__(
-     self, x: Array, t: Array, is_training: bool
- ) -> Array:
+ def __call__(self, x: Array, t: Array, is_training: bool) -> Array:
...


@@ -93,9 +91,12 @@ class ReFlowModel(models.BaseModel):
jax.random.uniform, dtype=jnp.float32
)

+ min_train_time: float = 1e-4 # This should be close to 0.
+ max_train_time: float = 1.0 - 1e-4 # It should be close to 1.
+
num_eval_cases_per_lvl: int = 1
min_eval_time_lvl: float = 1e-4 # This should be close to 0.
- max_eval_time_lvl: float = 1 - 1e-4 # It should be close to 1.
+ max_eval_time_lvl: float = 1.0 - 1e-4 # It should be close to 1.
num_eval_time_levels: ClassVar[int] = 10

def initialize(self, rng: Array):
@@ -130,9 +131,12 @@ def loss_fn(
batch_size = len(batch["x_0"])
time_sample_rng, dropout_rng = jax.random.split(rng, num=2)

- time_range = self.max_eval_time_lvl - self.min_eval_time_lvl
- time = (time_range * self.time_sampling(time_sample_rng, (batch_size,))
-         + self.min_eval_time_lvl)
+ time_range = self.max_train_time - self.min_train_time
+ # add the normal-logit sampling here.
+ time = (
+     time_range * self.time_sampling(time_sample_rng, (batch_size,))
+     + self.min_train_time
+ )

vmap_mult = jax.vmap(jnp.multiply, in_axes=(0, 0))

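For orientation, loss_fn above implements the standard rectified-flow objective: sample t in (0, 1), form the interpolant x_t = t * x_1 + (1 - t) * x_0, and regress the predicted velocity onto the straight-line target x_1 - x_0. A condensed, self-contained sketch (the apply_fn signature is a toy stand-in, not the repo's API):

import jax
import jax.numpy as jnp

def reflow_loss(params, apply_fn, x_0, x_1, rng, t_min=1e-4, t_max=1.0 - 1e-4):
  """Rectified-flow loss: E_t || v_theta(x_t, t) - (x_1 - x_0) ||^2."""
  t = jax.random.uniform(rng, (x_0.shape[0],), minval=t_min, maxval=t_max)
  t_b = t.reshape((-1,) + (1,) * (x_0.ndim - 1))  # broadcast over trailing axes
  x_t = t_b * x_1 + (1.0 - t_b) * x_0
  v = apply_fn(params, x_t, t)
  return jnp.mean(jnp.square(v - (x_1 - x_0)))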
@@ -165,23 +169,25 @@ def eval_fn(
Args:
variables: The full model variables for the flow module.
- batch: A batch of evaluation data expected to contain two fields `x_0`
-   and `x_1` fields, representing samples of each set. Both fields are
-   expected to have shape of `(batch, *spatial_dims, channels)`.
+ batch: A batch of evaluation data expected to contain two fields, `x_0`
+   and `x_1`, representing samples of each set. Both fields are expected
+   to have shape `(batch, *spatial_dims, channels)`.
rng: A Jax random key.
Returns:
A dictionary of evaluation metrics.
"""
- choice_rng_0, choice_rng_1 = jax.random.split(rng)
+ # We bootstrap the samples from the batch, but we keep them paired by using
+ # the same random number generator seed.
+ choice_rng, _ = jax.random.split(rng)
x_0 = jax.random.choice(
- key=choice_rng_0,
+ key=choice_rng,
a=batch["x_0"],
shape=(self.num_eval_time_levels, self.num_eval_cases_per_lvl),
)

x_1 = jax.random.choice(
- key=choice_rng_1,
+ key=choice_rng,
a=batch["x_1"],
shape=(self.num_eval_time_levels, self.num_eval_cases_per_lvl),
)
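The single shared choice_rng is what keeps the bootstrapped samples aligned: jax.random.choice with the same key, the same leading dimension, and the same output shape draws the same indices, even when the arrays differ. A standalone check of that property (illustrative values, assuming equal-length batches):

import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)
x0 = jnp.arange(10.0)
x1 = x0 + 100.0  # paired with x0, offset by 100
s0 = jax.random.choice(key, x0, shape=(4,))
s1 = jax.random.choice(key, x1, shape=(4,))
assert bool(jnp.all(s1 - s0 == 100.0))  # same key -> same indices -> still paired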
@@ -196,11 +202,9 @@ def eval_fn(
vmap_mult = jax.vmap(jnp.multiply, in_axes=(0, 0))
x_t = vmap_mult(x_1, time_eval) + vmap_mult(x_0, 1 - time_eval)
flow_fn = self.inference_fn(variables, self.flow_model)
- v_t = jax.vmap(flow_fn, in_axes=(1, None), out_axes=1)(
-     x_t, time_eval
- )
+ v_t = jax.vmap(flow_fn, in_axes=(1, None), out_axes=1)(x_t, time_eval)

- # Eq. (1) in [1].
+ # Eq. (1) in [1]. (by default in_axes=0 and out_axes=0 in vmap)
int_losses = jax.vmap(jnp.mean)(jnp.square((x_1 - x_0 - v_t)))
eval_losses = {f"time_lvl{i}": loss for i, loss in enumerate(int_losses)}

@@ -210,15 +214,11 @@ def eval_fn(
def inference_fn(variables: models.PyTree, flow_model: nn.Module):
"""Returns the inference flow function."""

- def _flow(
-     x: Array, time: float | Array
- ) -> Array:
+ def _flow(x: Array, time: float | Array) -> Array:
# This is a wrapper to vectorize time if it is a float.
if not jnp.shape(jnp.asarray(time)):
time *= jnp.ones((x.shape[0],))
- return flow_model.apply(
-     variables, x=x, sigma=time, is_training=False
- )
+ return flow_model.apply(variables, x=x, sigma=time, is_training=False)

return _flow

Expand Down Expand Up @@ -248,7 +248,7 @@ def __call__(
raise ValueError(
    f"Number of channels in the input ({x.shape[-1]}) must "
    "match the number of channels in the output "
-   f"({self.out_channel})."
+   f"({self.out_channels})."
)

if sigma.ndim < 1:
@@ -262,8 +262,6 @@ def __call__(

time = sigma * self.time_rescale

- f_x = super().__call__(
-     x, time, cond, is_training=is_training
- )
+ f_x = super().__call__(x, time, cond, is_training=is_training)

return f_x
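The _flow wrapper above promotes a scalar time to a per-example vector before calling the module, mirroring the sigma handling in __call__. A standalone sketch of that broadcasting pattern (the lambda is a toy flow, not the repo's model):

import jax.numpy as jnp

def make_flow(apply_fn):
  def _flow(x, time):
    # A python float or 0-d array becomes a (batch,) time vector.
    if not jnp.shape(jnp.asarray(time)):
      time = time * jnp.ones((x.shape[0],))
    return apply_fn(x, time)
  return _flow

flow = make_flow(lambda x, t: x * t[:, None])  # toy velocity field
y = flow(jnp.ones((4, 2)), 0.5)  # the scalar 0.5 is broadcast to shape (4,)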
