diff --git a/swirl_dynamics/lib/diffusion/vivit.py b/swirl_dynamics/lib/diffusion/vivit.py index 70f8c54..7728838 100644 --- a/swirl_dynamics/lib/diffusion/vivit.py +++ b/swirl_dynamics/lib/diffusion/vivit.py @@ -289,21 +289,19 @@ def __call__(self, x: Array, *, train: bool) -> Array: if self.kernel_init_method == 'central_frame_initializer': kernel_initializer = central_frame_initializer() - # logging.info('Using central frame initializer for input embedding') elif self.kernel_init_method == 'average_frame_initializer': kernel_initializer = average_frame_initializer() - # logging.info('Using average frame initializer for input embedding') else: kernel_initializer = linear.default_kernel_init - # logging.info('Using default initializer for input embedding') x = nn.Conv( - self.embedding_dim, (ft, fh, fw), + features=self.embedding_dim, + kernel_size=(ft, fh, fw), strides=(ft, fh, fw), padding='VALID', name='_conv_3d_embedding', - kernel_init=kernel_initializer)( - x) + kernel_init=kernel_initializer, + )(x) return x diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov.py b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov.py new file mode 100644 index 0000000..0a5f290 --- /dev/null +++ b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov.py @@ -0,0 +1,94 @@ +# Copyright 2023 The swirl_dynamics Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config file for ViViT Denoiser. 
+ + +""" + +import ml_collections +# pylint: disable=line-too-long +DATA_PATH = '/datasets/hdf5/pde/2d/ns/attractor_spectral_grid_256_spatial_downsample_4_dt_0.001_v0_3_warmup_40.0_t_final_200.0_nu_0.001_n_samples_2000_ntraj_train_256_ntraj_eval_32_ntraj_test_32_drag_0.1_wave_number_4_random_seeds_combined_4.hdf5' +# pylint: enable=line-too-long + + +def get_config(): + """Returns the base experiment configuration.""" + config = ml_collections.ConfigDict() + + # Model. + # TODO(lzepedanunez): undo all the nested dictionaries. + config.model_name = 'ViViT Denoiser' + config.model = ml_collections.ConfigDict() + config.model.hidden_size = 384 # 192 # 768 + config.spatial_downsample_factor = 2 + + config.model.num_heads = 12 + config.model.mlp_dim = 512 + config.model.num_layers = 6 + config.model.dropout_rate = 0.3 + config.model_dtype_str = 'float32' + config.model.noise_embed_dim = 256 + config.model.diffusion_scheme = 'variance_exploding' + + config.save_interval_steps = 1000 + config.max_checkpoints_to_keep = 10 + + # TODO(lzepedanunez): create custom data structures. + config.model.temporal_encoding_config = ml_collections.ConfigDict() + config.model.temporal_encoding_config.method = '3d_conv' + # pylint: disable=line-too-long + config.model.temporal_encoding_config.kernel_init_method = 'central_frame_initializer' + # pylint: enable=line-too-long + config.model.positional_embedding = 'sinusoidal_3d' # 'sinusoidal_3d' + + # TODO(lzepedanunez): patches doesn't need to be a dictionary. 
+ config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (4, 4, 4) # (time, height, width) + + config.model.attention_config = ml_collections.ConfigDict() + # config.model.attention_config.type = 'factorized_encoder' + config.model.attention_config.type = 'factorized_self_attention_block' + config.model.attention_config.attention_order = 'time_space' + config.model.attention_config.attention_kernel_init_method = 'xavier' + + config.data = ml_collections.ConfigDict() + config.data.file_path_data = DATA_PATH + config.data.num_time_steps = 32 + config.data.time_stride = 1 + config.data.batch_size = 8 + config.data.normalize = True + config.data.random_seed = 1 + config.data.tf_lookup_batch_size = 32 + config.data.std = 1.0 + config.data.space_shape = (64, 64, 1) + + config.optimizer = ml_collections.ConfigDict() + config.optimizer.num_train_steps = 1000000 + config.optimizer.initial_lr = 0.0 + config.optimizer.peak_lr = 3e-4 + config.optimizer.warmup_steps = 50000 + config.optimizer.end_lr = 1e-6 + config.optimizer.ema_decay = 0.999 + config.optimizer.ckpt_interval = 1000 + config.optimizer.max_ckpt_to_keep = 5 + config.optimizer.clip_min = 1e-4 + config.optimizer.metric_aggreration_steps = 50 + config.optimizer.eval_every_steps = 1000 + config.optimizer.num_batches_per_eval = 8 + config.optimizer.clip = 1. + config.optimizer.beta1 = 0.99 + + return config + diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_med_res.py b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_med_res.py new file mode 100644 index 0000000..a06caff --- /dev/null +++ b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_med_res.py @@ -0,0 +1,94 @@ +# Copyright 2023 The swirl_dynamics Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config file for ViViT Denoiser. + + +""" + +import ml_collections +# pylint: disable=line-too-long +DATA_PATH = '/datasets/hdf5/pde/2d/ns/attractor_spectral_grid_256_spatial_downsample_4_dt_0.001_v0_3_warmup_40.0_t_final_200.0_nu_0.001_n_samples_2000_ntraj_train_256_ntraj_eval_32_ntraj_test_32_drag_0.1_wave_number_4_random_seeds_combined_4.hdf5' +# pylint: enable=line-too-long + + +def get_config(): + """Returns the base experiment configuration.""" + config = ml_collections.ConfigDict() + + # Model. + # TODO(lzepedanunez): Undo all the nested dictionaries. + config.model_name = 'ViViT Denoiser' + config.model = ml_collections.ConfigDict() + config.model.hidden_size = 576 + config.spatial_downsample_factor = 1 + + config.model.num_heads = 18 + config.model.mlp_dim = 512 + config.model.num_layers = 6 + config.model.dropout_rate = 0.3 + config.model_dtype_str = 'float32' + config.model.noise_embed_dim = 256 + config.model.diffusion_scheme = 'variance_exploding' + + config.save_interval_steps = 1000 + config.max_checkpoints_to_keep = 10 + + # TODO(lzepedanunez): create custom data structures. + config.model.temporal_encoding_config = ml_collections.ConfigDict() + config.model.temporal_encoding_config.method = '3d_conv' + # pylint: disable=line-too-long + config.model.temporal_encoding_config.kernel_init_method = 'central_frame_initializer' + # pylint: enable=line-too-long + config.model.positional_embedding = 'sinusoidal_3d' # 'sinusoidal_3d' + + # TODO(lzepedanunez): patches doesn't need to be a dictionary. 
+ config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (4, 4, 4) # (time, height, width) + + config.model.attention_config = ml_collections.ConfigDict() + # config.model.attention_config.type = 'factorized_encoder' + config.model.attention_config.type = 'factorized_self_attention_block' + config.model.attention_config.attention_order = 'time_space' + config.model.attention_config.attention_kernel_init_method = 'xavier' + + config.data = ml_collections.ConfigDict() + config.data.file_path_data = DATA_PATH + config.data.num_time_steps = 32 + config.data.time_stride = 2 + config.data.batch_size = 8 + config.data.normalize = True + config.data.random_seed = 1 + config.data.tf_lookup_batch_size = 32 + config.data.std = 1.0 + config.data.space_shape = (64, 64, 1) + + config.optimizer = ml_collections.ConfigDict() + config.optimizer.num_train_steps = 1000000 + config.optimizer.initial_lr = 0.0 + config.optimizer.peak_lr = 3e-4 + config.optimizer.warmup_steps = 50000 + config.optimizer.end_lr = 1e-6 + config.optimizer.ema_decay = 0.999 + config.optimizer.ckpt_interval = 1000 + config.optimizer.max_ckpt_to_keep = 5 + config.optimizer.clip_min = 1e-4 + config.optimizer.metric_aggreration_steps = 50 + config.optimizer.eval_every_steps = 1000 + config.optimizer.num_batches_per_eval = 8 + config.optimizer.clip = 1. + config.optimizer.beta1 = 0.99 + + return config + diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_transformer.py b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_transformer.py new file mode 100644 index 0000000..e781a9b --- /dev/null +++ b/swirl_dynamics/projects/spatiotemporal_modeling/configs/kolmogorov_transformer.py @@ -0,0 +1,96 @@ +# Copyright 2023 The swirl_dynamics Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +r"""Config file for ViViT Denoiser. + + +""" + +import ml_collections +# pylint: disable=line-too-long +DATA_PATH = '/datasets/hdf5/pde/2d/ns/attractor_spectral_grid_256_spatial_downsample_4_dt_0.001_v0_3_warmup_40.0_t_final_200.0_nu_0.001_n_samples_2000_ntraj_train_256_ntraj_eval_32_ntraj_test_32_drag_0.1_wave_number_4_random_seeds_combined_4.hdf5' +# pylint: enable=line-too-long + + +def get_config(): + """Returns the base experiment configuration.""" + config = ml_collections.ConfigDict() + + # Model. + # TODO(lzepedanunez) undo all the nested dictionaries. + config.model_name = 'ViViT Denoiser' + config.model = ml_collections.ConfigDict() + config.model.hidden_size = 192 + config.spatial_downsample_factor = 2 + + config.model.num_heads = 12 + config.model.mlp_dim = 512 + config.model.num_layers = 6 + config.model.dropout_rate = 0.3 + config.model_dtype_str = 'float32' + config.model.noise_embed_dim = 256 + config.model.diffusion_scheme = 'variance_exploding' + + config.save_interval_steps = 1000 + config.max_checkpoints_to_keep = 10 + + # TODO(lzepedanunez): create custom data structures. 
+ config.model.temporal_encoding_config = ml_collections.ConfigDict() + config.model.temporal_encoding_config.method = '3d_conv' + # pylint: disable=line-too-long + # config.model.temporal_encoding_config.kernel_init_method = 'central_frame_initializer' + config.model.temporal_encoding_config.kernel_init_method = 'average_frame_initializer' + # pylint: enable=line-too-long + config.model.positional_embedding = 'none' # 'sinusoidal_3d' + + # TODO(lzepedanunez): patches doesn't need to be a dictionary. + config.model.patches = ml_collections.ConfigDict() + config.model.patches.size = (4, 4, 4) # (time, height, width) + + config.model.attention_config = ml_collections.ConfigDict() + # config.model.attention_config.type = 'factorized_encoder' + # config.model.attention_config.type = 'factorized_self_attention_block' + config.model.attention_config.type = 'spacetime' + config.model.attention_config.attention_order = 'time_space' + config.model.attention_config.attention_kernel_init_method = 'xavier' + + config.data = ml_collections.ConfigDict() + config.data.file_path_data = DATA_PATH + config.data.num_time_steps = 32 + config.data.time_stride = 1 + config.data.batch_size = 8 + config.data.normalize = True + config.data.random_seed = 1 + config.data.tf_lookup_batch_size = 32 + config.data.std = 1.0 + config.data.space_shape = (64, 64, 1) + + config.optimizer = ml_collections.ConfigDict() + config.optimizer.num_train_steps = 1000000 + config.optimizer.initial_lr = 0.0 + config.optimizer.peak_lr = 1e-5 + config.optimizer.warmup_steps = 50000 + config.optimizer.end_lr = 1e-7 + config.optimizer.ema_decay = 0.999 + config.optimizer.ckpt_interval = 1000 + config.optimizer.max_ckpt_to_keep = 5 + config.optimizer.clip_min = 1e-4 + config.optimizer.metric_aggreration_steps = 50 + config.optimizer.eval_every_steps = 1000 + config.optimizer.num_batches_per_eval = 8 + config.optimizer.clip = 1. 
+ config.optimizer.beta1 = 0.99 + + return config + diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/data_utils.py b/swirl_dynamics/projects/spatiotemporal_modeling/data_utils.py new file mode 100644 index 0000000..940c9f3 --- /dev/null +++ b/swirl_dynamics/projects/spatiotemporal_modeling/data_utils.py @@ -0,0 +1,154 @@ +# Copyright 2023 The swirl_dynamics Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Simple dataloader generating a short video of a fluid simulation.""" + +import grain.tensorflow as tfgrain +import jax +import numpy as np +from swirl_dynamics.data import hdf5_utils +from swirl_dynamics.data import tfgrain_transforms as transforms +import tensorflow as tf + +Array = jax.Array + + +def create_loader_from_hdf5( + num_time_steps: int, + time_stride: int, + batch_size: int, + dataset_path: str, + seed: int, + split: str | None = None, + spatial_downsample_factor: int = 1, + normalize: bool = False, + normalize_stats: dict[str, Array] | None = None, + use_time_normalization: bool = False, + tf_lookup_batch_size: int = 4, + tf_lookup_num_parallel_calls: int = -1, + tf_interleaved_shuffle: bool = False, +) -> tuple[tfgrain.TfDataLoader, dict[str, Array | None]]: + """Load pre-computed trajectories dumped to hdf5 file. + + If normalize flag is set, method will also return the mean and std used in + normalization (which are calculated from train split). 
+ + Arguments: + num_time_steps: Number of time steps to include in each trajectory. If set + to -1, use entire trajectory lengths. + time_stride: Stride of trajectory sampling. + batch_size: Batch size returned by dataloader. If set to -1, use entire + dataset size as batch_size. + dataset_path: Absolute path to dataset file. + seed: Random seed to be used in data sampling. + split: Data split - train, eval, test, or None. + spatial_downsample_factor: reduce spatial resolution by factor of x. + normalize: Flag for adding data normalization (subtract mean, divide by std.). + normalize_stats: Dictionary with mean and std stats to avoid recomputing. + use_time_normalization: Normalization is performed using a short sequence + as a unit, instead of a single snapshot. + tf_lookup_batch_size: Number of lookup batches (in cache) for grain. + tf_lookup_num_parallel_calls: Number of parallel calls for lookups in the + dataset. -1 is set to let grain optimize the number of calls. + tf_interleaved_shuffle: Using a more localized shuffle instead of a global + shuffle of the data. + + Returns: + loader, stats (optional): tuple of dataloader and dictionary containing + mean and std stats (if normalize=True, else dict + contains NoneType values). + """ + snapshots, tspan = hdf5_utils.read_arrays_as_tuple( + dataset_path, (f"{split}/u", f"{split}/t") + ) + if spatial_downsample_factor > 1: + if snapshots.ndim == 3: + snapshots = snapshots[:, :, ::spatial_downsample_factor] + elif snapshots.ndim == 4: + snapshots = snapshots[:, :, ::spatial_downsample_factor, :] + elif snapshots.ndim == 5: + snapshots = snapshots[ + :, :, ::spatial_downsample_factor, ::spatial_downsample_factor, : + ] + else: + raise NotImplementedError( + f"Number of dimensions {snapshots.ndim} not " + "supported for spatial downsampling." 
+ ) + + if normalize: + if normalize_stats is not None: + mean = normalize_stats["mean"] + std = normalize_stats["std"] + else: + if split != "train": + data_for_stats = hdf5_utils.read_single_array(dataset_path, "train/u") + else: + data_for_stats = snapshots + # TODO(lzepedanunez): For the sake of memory perform this in CPU. + if use_time_normalization: + num_trajs, num_frames, nx, ny, d = data_for_stats.shape + num_segments = num_frames // num_time_steps + data_for_stats = data_for_stats[:, :(num_segments * num_time_steps)] + data_for_stats = np.reshape( + data_for_stats, (num_trajs, num_segments, num_time_steps, nx, ny, d) + ) + + mean = np.mean(data_for_stats, axis=(0, 1)) + std = np.std(data_for_stats, axis=(0, 1)) + snapshots -= mean + snapshots /= std + else: + mean, std = None, None + source = tfgrain.TfInMemoryDataSource.from_dataset( + tf.data.Dataset.from_tensor_slices({ + "u": snapshots, # states + }) + ) + # This transform randomly takes a random section from the trajectory with the + # desired length and stride + if num_time_steps == -1: # Use full (downsampled) trajectories + num_time_steps = tspan.shape[1] // time_stride + section_transform = transforms.RandomSection( + feature_names=("u",), + num_steps=num_time_steps, + stride=time_stride, + ) + + rename = transforms.SelectAs(select_features=("u",), as_features=("x",)) + + dataset_transforms = (section_transform, rename) + + # Grain fine-tuning. 
+ tfgrain.config.update( + "tf_lookup_num_parallel_calls", tf_lookup_num_parallel_calls + ) + tfgrain.config.update("tf_interleaved_shuffle", tf_interleaved_shuffle) + tfgrain.config.update("tf_lookup_batch_size", tf_lookup_batch_size) + + if batch_size == -1: # Use full dataset as batch + batch_size = len(source) + loader = tfgrain.TfDataLoader( + source=source, + sampler=tfgrain.TfDefaultIndexSampler( + num_records=len(source), + seed=seed, + num_epochs=None, # loads indefinitely + shuffle=True, + shard_options=tfgrain.ShardByJaxProcess(drop_remainder=True), + ), + transformations=dataset_transforms, + batch_fn=tfgrain.TfBatch(batch_size=batch_size, drop_remainder=False), + ) + return loader, {"mean": mean, "std": std} diff --git a/swirl_dynamics/projects/spatiotemporal_modeling/main.py b/swirl_dynamics/projects/spatiotemporal_modeling/main.py new file mode 100644 index 0000000..658d51e --- /dev/null +++ b/swirl_dynamics/projects/spatiotemporal_modeling/main.py @@ -0,0 +1,198 @@ +# Copyright 2023 The swirl_dynamics Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +r"""The main entry point for running training loops.""" + +import json +from os import path as osp + +from absl import app +from absl import flags +import jax +from ml_collections import config_flags +import optax +from orbax import checkpoint +from swirl_dynamics.lib.diffusion import diffusion +from swirl_dynamics.lib.diffusion import vivit_diffusion +from swirl_dynamics.projects.probabilistic_diffusion import models +from swirl_dynamics.projects.probabilistic_diffusion import trainers +from swirl_dynamics.projects.spatiotemporal_modeling import data_utils +from swirl_dynamics.templates import callbacks +from swirl_dynamics.templates import train +import tensorflow as tf + + +FLAGS = flags.FLAGS + +flags.DEFINE_string("workdir", None, "Directory to store model data.") +config_flags.DEFINE_config_file( + "config", + None, + "File path to the training hyperparameter configuration.", + lock_config=True, +) + + +def main(argv): + if len(argv) > 1: + raise app.UsageError("Too many command-line arguments.") + config = FLAGS.config + # Dump config as json to workdir. + workdir = FLAGS.workdir + if not tf.io.gfile.exists(workdir): + tf.io.gfile.makedirs(workdir) + # Only 0-th process should write the json file to disk, in order to avoid + # race conditions. + if jax.process_index() == 0: + with tf.io.gfile.GFile(name=osp.join(workdir, + "config.json"), mode="w") as f: + conf_json = config.to_json_best_effort() + if isinstance(conf_json, str): # Sometimes `.to_json()` returns string + conf_json = json.loads(conf_json) + json.dump(conf_json, f) + tf.config.experimental.set_visible_devices([], "GPU") + + # Defining experiments through the config file. 
+ schedule = optax.warmup_cosine_decay_schedule( + init_value=config.optimizer.initial_lr, + peak_value=config.optimizer.peak_lr, + warmup_steps=config.optimizer.warmup_steps, + decay_steps=config.optimizer.num_train_steps, + end_value=config.optimizer.end_lr, + ) + + optimizer = optax.chain( + optax.clip(config.optimizer.clip), + optax.adam( + learning_rate=schedule, + b1=config.optimizer.beta1, + ), + ) + + train_dataloader, stats = data_utils.create_loader_from_hdf5( + num_time_steps=config.data.num_time_steps, + time_stride=config.data.time_stride, + batch_size=config.data.batch_size, + spatial_downsample_factor=config.spatial_downsample_factor, + dataset_path=config.data.file_path_data, + seed=config.data.random_seed, + tf_lookup_batch_size=config.data.tf_lookup_batch_size, + split="train", + normalize=config.data.normalize, + ) + + eval_dataloader, _ = data_utils.create_loader_from_hdf5( + num_time_steps=config.data.num_time_steps, + time_stride=config.data.time_stride, + batch_size=config.data.batch_size, + normalize_stats=stats, + spatial_downsample_factor=config.spatial_downsample_factor, + dataset_path=config.data.file_path_data, + tf_lookup_batch_size=config.data.tf_lookup_batch_size, + seed=config.data.random_seed, + split="eval", + normalize=config.data.normalize, + ) + + # Setting up the denoiser neural network. + denoiser_model = vivit_diffusion.PreconditionedDenoiser( + mlp_dim=config.model.mlp_dim, + num_layers=config.model.num_layers, + num_heads=config.model.num_heads, + output_features=1, + noise_embed_dim=config.model.noise_embed_dim, + patches=config.model.patches, + hidden_size=config.model.hidden_size, + temporal_encoding_config=config.model.temporal_encoding_config, + attention_config=config.model.attention_config, + positional_embedding=config.model.positional_embedding, + sigma_data=1.0, # standard deviation of the entire dataset. 
+ ) + + if config.model.diffusion_scheme == "variance_exploding": + diffusion_scheme = diffusion.Diffusion.create_variance_exploding( + sigma=diffusion.tangent_noise_schedule( + clip_max=80.0, start=-1.5, end=1.5 + ), + data_std=config.data.std, + ) + elif config.model.diffusion_scheme == "variance_preserving": + diffusion_scheme = diffusion.Diffusion.create_variance_preserving( + sigma=diffusion.tangent_noise_schedule(), + data_std=config.data.std, + ) + else: + raise ValueError( + f"Unknown diffusion scheme: {config.model.diffusion_scheme}" + ) + + model = models.DenoisingModel( + input_shape=( + config.data.num_time_steps, + config.data.space_shape[0] // config.spatial_downsample_factor, + config.data.space_shape[1] // config.spatial_downsample_factor, + config.data.space_shape[2], + ), # This must agree with the expected sample shape. + denoiser=denoiser_model, + noise_sampling=diffusion.log_uniform_sampling( + diffusion_scheme, # pylint: disable=undefined-variable + clip_min=config.optimizer.clip_min, + uniform_grid=True, + ), + noise_weighting=diffusion.edm_weighting(data_std=1.0), + ) + + # Defining the trainer. + trainer = trainers.DenoisingTrainer( + model=model, + rng=jax.random.PRNGKey(888), + optimizer=optimizer, + # This option is to minimize the colorshift. + ema_decay=config.optimizer.ema_decay, + ) + + # Setting up checkpointing. + ckpt_options = checkpoint.CheckpointManagerOptions( + save_interval_steps=config.save_interval_steps, + max_to_keep=config.max_checkpoints_to_keep, + ) + + # Run training loop. + train.run( + train_dataloader=train_dataloader, + trainer=trainer, + workdir=workdir, + total_train_steps=config.optimizer.num_train_steps, + metric_aggregation_steps=config.optimizer.metric_aggreration_steps, # 30 + eval_dataloader=eval_dataloader, + eval_every_steps=config.optimizer.eval_every_steps, + num_batches_per_eval=config.optimizer.num_batches_per_eval, + callbacks=( + # This callback displays the training progress in a tqdm bar. 
+ callbacks.TqdmProgressBar( + total_train_steps=config.optimizer.num_train_steps, + train_monitors=("train_loss",), + ), + # This callback saves model checkpoint periodically. + callbacks.TrainStateCheckpoint( + base_dir=workdir, + options=ckpt_options, + ), + # TODO(lzepedanunez) add a plot callback. + ), + ) + + +if __name__ == "__main__": + app.run(main)