Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 614091401
  • Loading branch information
Forgotten authored and The swirl_dynamics Authors committed Mar 9, 2024
1 parent 5421fb1 commit 6f56b65
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 249 deletions.
2 changes: 1 addition & 1 deletion swirl_dynamics/lib/diffusion/unets.py
Original file line number Diff line number Diff line change
Expand Up @@ -529,7 +529,7 @@ def __call__(
name="conv_out",
)(h)

if self.resize_to_shape is not None:
if self.resize_to_shape:
h = layers.FilteredResize(
output_size=input_size, kernel_size=(7, 7), padding=self.padding
)(h)
Expand Down
213 changes: 0 additions & 213 deletions swirl_dynamics/projects/debiasing/rectified_flow/data_utils.py

This file was deleted.

121 changes: 88 additions & 33 deletions swirl_dynamics/projects/debiasing/rectified_flow/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,31 +78,77 @@ def main(argv):
),
)

train_dataloader = data_utils.UnpairedDataLoader(
batch_size=config.batch_size,
dataset_path_a=config.dataset_path_u_lf,
dataset_path_b=config.dataset_path_u_hf,
seed=config.seed,
split="train",
spatial_downsample_factor_a=config.downsample_factor[0],
normalize=config.normalize,
tf_lookup_batch_size=config.tf_lookup_batch_size,
tf_lookup_num_parallel_calls=config.tf_lookup_num_parallel_calls,
tf_interleaved_shuffle=config.tf_interleaved_shuffle,
)

eval_dataloader = data_utils.UnpairedDataLoader(
batch_size=config.batch_size,
dataset_path_a=config.dataset_path_u_lf,
dataset_path_b=config.dataset_path_u_hf,
seed=config.seed,
split="eval",
spatial_downsample_factor_b=config.downsample_factor[1],
normalize=config.normalize,
tf_lookup_batch_size=config.tf_lookup_batch_size,
tf_lookup_num_parallel_calls=config.tf_lookup_num_parallel_calls,
tf_interleaved_shuffle=config.tf_interleaved_shuffle,
)
if config.tf_grain_hdf5:

train_dataloader = data_utils.UnpairedDataLoader(
batch_size=config.batch_size,
dataset_path_a=config.dataset_path_u_lf,
dataset_path_b=config.dataset_path_u_hf,
seed=config.seed,
split="train",
spatial_downsample_factor_a=config.spatial_downsample_factor[0],
normalize=config.normalize,
tf_lookup_batch_size=config.tf_lookup_batch_size,
tf_lookup_num_parallel_calls=config.tf_lookup_num_parallel_calls,
tf_interleaved_shuffle=config.tf_interleaved_shuffle,
)

eval_dataloader = data_utils.UnpairedDataLoader(
batch_size=config.batch_size,
dataset_path_a=config.dataset_path_u_lf,
dataset_path_b=config.dataset_path_u_hf,
seed=config.seed,
split="eval",
spatial_downsample_factor_b=config.spatial_downsample_factor[1],
normalize=config.normalize,
tf_lookup_batch_size=config.tf_lookup_batch_size,
tf_lookup_num_parallel_calls=config.tf_lookup_num_parallel_calls,
tf_interleaved_shuffle=config.tf_interleaved_shuffle,
)
elif config.pygrain_zarr:

era5_loader_train = data_utils.create_era5_loader(
date_range=config.data_range_train,
shuffle=True,
seed=config.seed,
batch_size=config.batch_size,
drop_remainder=True,
worker_count=0,)

lens2_loader_train = data_utils.create_lens2_loader(
date_range=config.data_range_train,
shuffle=True,
seed=config.seed,
batch_size=config.batch_size,
drop_remainder=True,
interp_shapes=config.interp_shapes,
worker_count=0,)

train_dataloader = data_utils.DualLens2Era5Dataset(era5_loader_train,
lens2_loader_train)

era5_loader_eval = data_utils.create_era5_loader(
date_range=config.data_range_eval,
shuffle=True,
seed=config.seed,
batch_size=config.batch_size_eval,
drop_remainder=True,
worker_count=0,)

lens2_loader_eval = data_utils.create_lens2_loader(
date_range=config.data_range_eval,
shuffle=True,
seed=config.seed,
batch_size=config.batch_size_eval,
drop_remainder=True,
interp_shapes=config.interp_shapes,
worker_count=0,)

eval_dataloader = data_utils.DualLens2Era5Dataset(era5_loader_eval,
lens2_loader_eval)
else: # to avoid the linter to complain.
train_dataloader = None
eval_dataloader = None

# Setting up the neural network for the flow model.
flow_model = models.RescaledUnet(
Expand All @@ -114,27 +160,36 @@ def main(argv):
padding=config.padding,
dropout_rate=config.dropout_rate,
use_attention=config.use_attention,
resize_to_shape=config.resize_to_shape,
use_position_encoding=config.use_position_encoding,
num_heads=config.num_heads,
)

model = models.ReFlowModel(
# TODO: clean this part.
input_shape=(
config.input_shapes[0][1] // config.spatial_downsample_factor[0],
config.input_shapes[0][2] // config.spatial_downsample_factor[0],
config.input_shapes[0][1],
config.input_shapes[0][2],
config.input_shapes[0][3],
), # This must agree with the expected sample shape.
flow_model=flow_model,
)

# Defining the trainer.
trainer = trainers.ReFlowTrainer(
model=model,
rng=jax.random.PRNGKey(config.seed),
optimizer=optimizer,
ema_decay=config.ema_decay,
)
if config.distributed:
trainer = trainers.DistributedReFlowTrainer(
model=model,
rng=jax.random.PRNGKey(config.seed),
optimizer=optimizer,
ema_decay=config.ema_decay,
)
else:
trainer = trainers.ReFlowTrainer(
model=model,
rng=jax.random.PRNGKey(config.seed),
optimizer=optimizer,
ema_decay=config.ema_decay,
)

# Setting up checkpointing.
ckpt_options = checkpoint.CheckpointManagerOptions(
Expand Down
4 changes: 2 additions & 2 deletions swirl_dynamics/projects/debiasing/rectified_flow/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@
expression that the loss is pursuing.
References:
[1]: Xingchao Liu, Chengyue Gong and Qiang Liu. "Flow Straight and Fast:
[1] Xingchao Liu, Chengyue Gong and Qiang Liu. "Flow Straight and Fast:
Learning to Generate and Transfer Data with Rectified Flow" NeurIPS 2022,
Workshop on Score-Based Methods.
"""
Expand Down Expand Up @@ -89,7 +89,7 @@ class ReFlowModel(models.BaseModel):

input_shape: tuple[int, ...]
flow_model: nn.Module
time_sampling: Callable[[Array, tuple[int]], Array] = functools.partial(
time_sampling: Callable[[Array, tuple[int, ...]], Array] = functools.partial(
jax.random.uniform, dtype=jnp.float32
)

Expand Down

0 comments on commit 6f56b65

Please sign in to comment.