diff --git a/swirl_dynamics/projects/ergodic/choices.py b/swirl_dynamics/projects/ergodic/choices.py index a7fd049..a3a56b7 100644 --- a/swirl_dynamics/projects/ergodic/choices.py +++ b/swirl_dynamics/projects/ergodic/choices.py @@ -30,10 +30,12 @@ from swirl_dynamics.lib.networks import nonlinear_fourier from swirl_dynamics.lib.solvers import ode from swirl_dynamics.projects.ergodic import measure_distances +from swirl_dynamics.projects.ergodic import rollout_weighting Array = jax.Array MeasureDistFn = Callable[[Array, Array], float | Array] +RolloutWeightingFn = Callable[[int], Array] class Experiment(enum.Enum): @@ -156,6 +158,44 @@ def dispatch(self, conf: ml_collections.ConfigDict) -> nn.Module: out_channels=conf.out_channels, num_modes=conf.num_modes, width=conf.width, - fft_norm=conf.fft_norm + fft_norm=conf.fft_norm, + ) + raise ValueError() + + +class RolloutWeighting(enum.Enum): + """Rollout weighting choices.""" + + GEOMETRIC = "geometric" + INV_SQRT = "inv_sqrt" + INV_SQUARED = "inv_squared" + LINEAR = "linear" + NO_WEIGHT = "no_weight" + + def dispatch(self, conf: ml_collections.ConfigDict) -> RolloutWeightingFn: + """Dispatch rollout weighting.""" + if self.value == RolloutWeighting.GEOMETRIC.value: + return functools.partial( + rollout_weighting.geometric, + r=conf.rollout_weighting_r, + clip=conf.rollout_weighting_clip + ) + if self.value == RolloutWeighting.INV_SQRT.value: + return functools.partial( + rollout_weighting.inverse_sqrt, + clip=conf.rollout_weighting_clip + ) + if self.value == RolloutWeighting.INV_SQUARED.value: + return functools.partial( + rollout_weighting.inverse_squared, + clip=conf.rollout_weighting_clip + ) + if self.value == RolloutWeighting.LINEAR.value: + return functools.partial( + rollout_weighting.linear, + m=conf.rollout_weighting_m, + clip=conf.rollout_weighting_clip ) + if self.value == RolloutWeighting.NO_WEIGHT.value: + return rollout_weighting.no_weight raise ValueError() diff --git a/swirl_dynamics/projects/ergodic/colabs/demo.ipynb b/swirl_dynamics/projects/ergodic/colabs/demo.ipynb index 02ab8c4..51512be 100644 --- a/swirl_dynamics/projects/ergodic/colabs/demo.ipynb +++ b/swirl_dynamics/projects/ergodic/colabs/demo.ipynb @@ -128,13 +128,13 @@ }, "outputs": [], "source": [ - "experiment = \"ks_1d\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n", - "batch_size = 128 #@param {type:\"integer\"}\n", + "experiment = \"ns_2d\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n", + "batch_size = 512 #@param {type:\"integer\"}\n", "measure_dist_type = \"MMD\" #@param choices=['MMD', 'SD']\n", "normalize = False #@param {type:\"boolean\"}\n", "add_noise = False #@param {type:\"boolean\"}\n", "use_curriculum = True #@param {type:\"boolean\"}\n", - "use_pushfwd = False #@param {type:\"boolean\"}\n", + "use_pushfwd = True #@param {type:\"boolean\"}\n", "measure_dist_lambda = 0.0 #@param {type:\"number\"}\n", "measure_dist_k_lambda = 0.0 #@param {type:\"number\"}\n", "display_config = True #@param {type:\"boolean\"}\n", @@ -183,7 +183,6 @@ "cell_type": "code", "execution_count": null, "metadata": { - "cellView": "form", "id": "3oO4PJjqk-6i" }, "outputs": [], @@ -272,6 +271,9 @@ "\n", "# Trainer\n", "trainer_config = stable_ar.StableARTrainerConfig(\n", + " rollout_weighting=choices.RolloutWeighting(\n", + " config.rollout_weighting\n", + " ).dispatch(config),\n", " num_rollout_steps=config.num_rollout_steps,\n", " num_lookback_steps=config.num_lookback_steps,\n", " add_noise=config.add_noise,\n", @@ -311,7 +313,7 @@ " ),\n", " 
callbacks.TqdmProgressBar(\n", " total_train_steps=config.train_steps,\n", - " train_monitors=[\"rollout\", \"loss\", \"measure_dist\", \"measure_dist_k\"],\n", + " train_monitors=[\"rollout\", \"loss\", \"measure_dist\", \"measure_dist_k\", \"max_rollout_decay\"],\n", " eval_monitors=[\"sd\"],\n", " ),\n", " fig_callback_cls()\n", @@ -323,7 +325,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "cxxB4OZBuAA2" + "id": "-8lZXGG4wFks" }, "outputs": [], "source": [ @@ -333,7 +335,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "2AT2YThCyNPt" + "id": "fu6IxOuFhULW" }, "outputs": [], "source": [] @@ -345,7 +347,7 @@ "d_9f5Fwifhd9" ], "last_runtime": { - "build_target": "//learning/grp/tools/ml_python:ml_notebook", + "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu", "kind": "private" }, "private_outputs": true, diff --git a/swirl_dynamics/projects/ergodic/configs/ks_1d.py b/swirl_dynamics/projects/ergodic/configs/ks_1d.py index 6554cc7..5856eec 100644 --- a/swirl_dynamics/projects/ergodic/configs/ks_1d.py +++ b/swirl_dynamics/projects/ergodic/configs/ks_1d.py @@ -28,11 +28,12 @@ def get_config(): config = ml_collections.ConfigDict() config.experiment = 'ks_1d' # Train params - config.train_steps = 50_000 + config.train_steps = 300_000 config.seed = 42 config.lr = 1e-4 + config.use_lr_scheduler = True config.metric_aggregation_steps = 50 - config.save_interval_steps = 5_000 + config.save_interval_steps = 30_000 config.max_checkpoints_to_keep = 10 # Data params config.batch_size = 128 @@ -50,20 +51,19 @@ def get_config(): config.order_sobolev_norm = 1 config.noise_level = 0.0 - # TODO(yairschiff): Split different models into separate configs # Model params + config.model = 'PeriodicConvNetModel' # 'Fno' + # TODO(yairschiff): Split CNN and FNO into separate configs ########### PeriodicConvNetModel ################ - # config.model = 'PeriodicConvNetModel' - # config.latent_dim = 48 - # config.num_levels = 4 - # config.num_processors = 4 - # config.encoder_kernel_size = (5,) - # config.decoder_kernel_size = (5,) - # config.processor_kernel_size = (5,) - # config.padding = 'CIRCULAR' - # config.is_input_residual = True + config.latent_dim = 48 + config.num_levels = 4 + config.num_processors = 4 + config.encoder_kernel_size = (5,) + config.decoder_kernel_size = (5,) + config.processor_kernel_size = (5,) + config.padding = 'CIRCULAR' + config.is_input_residual = True ########### FNO ################ - config.model = 'Fno' config.out_channels = 1 config.hidden_channels = 64 config.num_modes = (24,) @@ -77,14 +77,17 @@ def get_config(): config.num_lookback_steps = 1 # Update num_time_steps and integrator based on num_lookback_steps setting - config.num_time_steps += config.num_lookback_steps - 1 - if config.num_lookback_steps > 1: + config.num_time_steps += config.get_ref('num_lookback_steps') - 1 + if config.get_ref('num_lookback_steps') > 1: config.integrator = 'MultiStepDirect' else: config.integrator = 'OneStepDirect' # Trainer params + config.rollout_weighting = 'geometric' + config.rollout_weighting_r = 0.9 + config.rollout_weighting_clip = 10e-4 config.num_rollout_steps = 1 - config.train_steps_per_cycle = 5_000 + config.train_steps_per_cycle = 60_000 config.time_steps_increase_per_cycle = 1 config.use_curriculum = False # Sweepable config.use_pushfwd = False # Sweepable @@ -106,6 +109,8 @@ def skip( if not use_curriculum and use_pushfwd: return True + if measure_dist_lambda > 0.0 and measure_dist_k_lambda == 0.0: + return True if ( 
measure_dist_type == 'SD' and measure_dist_lambda == 0.0 @@ -119,87 +124,79 @@ def skip( # use option --sweep=False in the command line to avoid sweeping def sweep(add): """Define param sweep.""" - for seed in [42]: - for normalize in [False, True]: - for batch_size in [128]: - for lr in [1e-4]: - for use_curriculum in [False, True]: - for use_pushfwd in [False, True]: - for measure_dist_type in ['MMD', 'SD']: - for measure_dist_lambda in [0.0, 1.0]: - for measure_dist_k_lambda in [0.0, 1.0, 100.0]: - if use_curriculum: - train_steps_per_cycle = 50_000 - time_steps_increase_per_cycle = 1 - else: - train_steps_per_cycle = 0 - time_steps_increase_per_cycle = 0 - if skip( - use_curriculum, - use_pushfwd, - measure_dist_lambda, - measure_dist_k_lambda, - measure_dist_type, - ): - continue - add( - seed=seed, - batch_size=batch_size, - normalize=normalize, - lr=lr, - measure_dist_type=measure_dist_type, - train_steps_per_cycle=train_steps_per_cycle, - time_steps_increase_per_cycle=time_steps_increase_per_cycle, - use_curriculum=use_curriculum, - use_pushfwd=use_pushfwd, - measure_dist_lambda=measure_dist_lambda, - measure_dist_k_lambda=measure_dist_k_lambda, - ) + # pylint: disable=line-too-long + for seed in [1, 11, 21, 31, 42]: + for normalize in [True]: + for model in ['PeriodicConvNetModel']: + for batch_size in [32, 64, 128, 256]: + for lr in [5e-4]: + for use_curriculum in [True]: + for use_pushfwd in [True]: + for measure_dist_type in ['MMD', 'SD']: + for measure_dist_lambda in [0.0, 1.0]: + for measure_dist_k_lambda in [0.0, 1.0, 10.0, 100.0, 1000.0]: + if use_curriculum: + train_steps_per_cycle = 60_000 + time_steps_increase_per_cycle = 1 + else: + train_steps_per_cycle = 0 + time_steps_increase_per_cycle = 0 + if skip( + use_curriculum, + use_pushfwd, + measure_dist_lambda, + measure_dist_k_lambda, + measure_dist_type, + ): + continue + add( + seed=seed, + normalize=normalize, + model=model, + batch_size=batch_size, + lr=lr, + measure_dist_type=measure_dist_type, + train_steps_per_cycle=train_steps_per_cycle, + time_steps_increase_per_cycle=time_steps_increase_per_cycle, + use_curriculum=use_curriculum, + use_pushfwd=use_pushfwd, + measure_dist_lambda=measure_dist_lambda, + measure_dist_k_lambda=measure_dist_k_lambda, + ) +# TODO(yairschiff): Ablation! 
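For reference, the `geometric` rollout weighting that the ks_1d config above selects (`r=0.9`, `clip=10e-4`) yields the per-step weights sketched below. This is an illustrative restatement of `rollout_weighting.geometric` introduced by this change, not part of the patch itself:

```python
import jax.numpy as jnp

# Illustrative sketch mirroring rollout_weighting.geometric from this change.
def geometric(num_time_steps: int, r: float = 0.9, clip: float = 10e-4):
  # Weights r^0, r^1, ..., r^(k-1) for the k = num_time_steps - 1 rollout
  # steps, floored at `clip` so late steps never drop out of the loss.
  return jnp.clip(r ** jnp.arange(0, num_time_steps - 1), a_min=clip)

print(geometric(6))  # [1.0, 0.9, 0.81, 0.729, 0.6561]
```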
# def sweep(add): # """Define param sweep.""" -# for seed in [21, 42, 84]: -# for measure_dist_type in ['MMD', 'SD']: -# for batch_size in [32, 64, 128, 256]: -# for lr in [1e-3, 1e-4, 1e-5]: -# # Skipping 1-step objective -# for use_curriculum in [True]: # [False, True]: -# # Running grid search on just Pfwd -# for use_pushfwd in [True]: # [False, True]: -# # Skipping all x0 regs -# for regularize_measure_dist in [False]: # [False, True]: -# for regularize_measure_dist_k in [False, True]: -# for measure_dist_lambda in [0.0, 1.0, 100.0]: -# if use_curriculum: -# train_steps_per_cycle = 50_000 -# time_steps_increase_per_cycle = 1 -# else: -# train_steps_per_cycle = 0 -# time_steps_increase_per_cycle = 0 -# if skip( -# use_curriculum, -# use_pushfwd, -# regularize_measure_dist, -# regularize_measure_dist_k, -# measure_dist_lambda, -# measure_dist_type -# ): -# continue -# add( -# num_time_steps=61, -# train_steps=3_000_000, -# save_interval_steps=250_000, -# max_checkpoints_to_keep=3_000_000//250_000, -# seed=seed, -# batch_size=batch_size, -# lr=lr, -# measure_dist_type=measure_dist_type, -# train_steps_per_cycle=train_steps_per_cycle, -# time_steps_increase_per_cycle=time_steps_increase_per_cycle, -# use_curriculum=use_curriculum, -# use_pushfwd=use_pushfwd, -# regularize_measure_dist=regularize_measure_dist, -# regularize_measure_dist_k=regularize_measure_dist_k, -# measure_dist_lambda=measure_dist_lambda, -# ) +# # pylint: disable=line-too-long +# for seed in [1, 11, 21, 31, 42]: +# for normalize in [True]: +# for model in ['PeriodicConvNetModel']: +# for batch_size in [32, 64, 128, 256, 512]: +# for lr in [5e-4]: +# for use_curriculum in [True]: +# for use_pushfwd in [True]: +# for measure_dist_type, measure_dist_lambda, measure_dist_k_lambda in [('MMD', 0.0, 10.0), ('SD', 0.0, 10.0)]: #[('MMD', 0.0, 0.0), ('MMD', 1.0, 1000.0), ('SD', 0.0, 1.0)]: +# train_steps_per_cycle = 60_000 +# time_steps_increase_per_cycle = 1 +# train_steps = 300_000 +# max_checkpoints_to_keep = 10 +# num_time_steps = 11 +# add( +# seed=seed, +# train_steps=train_steps, +# max_checkpoints_to_keep=max_checkpoints_to_keep, +# num_time_steps=num_time_steps, +# normalize=normalize, +# model=model, +# batch_size=batch_size, +# lr=lr, +# measure_dist_type=measure_dist_type, +# train_steps_per_cycle=train_steps_per_cycle, +# time_steps_increase_per_cycle=time_steps_increase_per_cycle, +# use_curriculum=use_curriculum, +# use_pushfwd=use_pushfwd, +# measure_dist_lambda=measure_dist_lambda, +# measure_dist_k_lambda=measure_dist_k_lambda, +# ) +# # pylint: enable=line-too-long diff --git a/swirl_dynamics/projects/ergodic/configs/lorenz63.py b/swirl_dynamics/projects/ergodic/configs/lorenz63.py index 28ec202..6855a5f 100644 --- a/swirl_dynamics/projects/ergodic/configs/lorenz63.py +++ b/swirl_dynamics/projects/ergodic/configs/lorenz63.py @@ -31,6 +31,7 @@ def get_config(): config.train_steps = 500_000 config.seed = 42 config.lr = 1e-4 + config.use_lr_scheduler = False config.metric_aggregation_steps = 50 config.save_interval_steps = 50_000 config.max_checkpoints_to_keep = 10 @@ -58,6 +59,9 @@ def get_config(): # Update num_time_steps based on num_lookback_steps setting config.num_time_steps += config.num_lookback_steps - 1 # Trainer params + config.rollout_weighting = 'geometric' + config.rollout_weighting_r = 0.1 + config.rollout_weighting_clip = 10e-8 config.num_rollout_steps = 1 config.train_steps_per_cycle = 50_000 config.time_steps_increase_per_cycle = 0 @@ -95,42 +99,43 @@ def skip( # TODO(yairschiff): Refactor sweeps 
and experiment definition to use gin. def sweep(add): """Define param sweep.""" - for seed in [21, 42]: + # pylint: disable=line-too-long + for seed in [1, 11, 21, 31, 42]: for batch_size in [2048]: - for time_stride in [50]: - for normalize in [True]: - for measure_dist_type in ['MMD', 'SD']: - for use_curriculum in [True]: - for use_pushfwd in [False]: - for measure_dist_lambda in [0.0]: #, 1.0]: - for measure_dist_k_lambda in [0.0]: #, 1.0, 1000.0]: - if use_curriculum: - train_steps = 2_000_000 - train_steps_per_cycle = 200_000 - time_steps_increase_per_cycle = 1 - else: - train_steps = 500_000 - train_steps_per_cycle = 0 - time_steps_increase_per_cycle = 0 - if skip( - use_curriculum, - use_pushfwd, - measure_dist_lambda, - measure_dist_k_lambda, - measure_dist_type, - ): - continue - add( - seed=seed, - batch_size=batch_size, - time_stride=time_stride, - normalize=normalize, - measure_dist_type=measure_dist_type, - train_steps=train_steps, - train_steps_per_cycle=train_steps_per_cycle, - time_steps_increase_per_cycle=time_steps_increase_per_cycle, - use_curriculum=use_curriculum, - use_pushfwd=use_pushfwd, - measure_dist_lambda=measure_dist_lambda, - measure_dist_k_lambda=measure_dist_k_lambda, - ) + for lr in [1e-4]: + for time_stride in [40]: + for normalize in [True]: + for measure_dist_type in ['MMD', 'SD']: + for use_curriculum in [True]: + for use_pushfwd in [False, True]: + for measure_dist_lambda in [0.0, 1.0]: + for measure_dist_k_lambda in [0.0, 1.0, 10.0, 100.0, 1000.0]: + if use_curriculum: + train_steps_per_cycle = 50_000 + time_steps_increase_per_cycle = 1 + else: + train_steps_per_cycle = 0 + time_steps_increase_per_cycle = 0 + if skip( + use_curriculum, + use_pushfwd, + measure_dist_lambda, + measure_dist_k_lambda, + measure_dist_type, + ): + continue + add( + seed=seed, + batch_size=batch_size, + lr=lr, + time_stride=time_stride, + normalize=normalize, + measure_dist_type=measure_dist_type, + train_steps_per_cycle=train_steps_per_cycle, + time_steps_increase_per_cycle=time_steps_increase_per_cycle, + use_curriculum=use_curriculum, + use_pushfwd=use_pushfwd, + measure_dist_lambda=measure_dist_lambda, + measure_dist_k_lambda=measure_dist_k_lambda, + ) + # pylint: enable=line-too-long diff --git a/swirl_dynamics/projects/ergodic/configs/ns_2d.py b/swirl_dynamics/projects/ergodic/configs/ns_2d.py index 06ae6dd..7177f1a 100644 --- a/swirl_dynamics/projects/ergodic/configs/ns_2d.py +++ b/swirl_dynamics/projects/ergodic/configs/ns_2d.py @@ -19,8 +19,8 @@ import ml_collections # pylint: disable=line-too-long -DATA_PATH = '/datasets/hdf5/pde/2d/ns/ns_trajectories_from_caltech.hdf5' -# DATA_PATH = '/datasets/hdf5/pde/2d/ns/attractor_spectral_grid_256_spatial_downsample_4_dt_0.001_v0_3_warmup_40.0_t_final_200.0_nu_0.001_n_samples_2000_ntraj_train_128_ntraj_eval_32_ntraj_test_32_drag_0.1_wave_number_4_random_seeds_combined_4.hdf5' +# DATA_PATH = '/datasets/hdf5/pde/2d/ns/ns_trajectories_from_caltech.hdf5' +DATA_PATH = '/datasets/hdf5/pde/2d/ns/attractor_spectral_grid_256_spatial_downsample_4_dt_0.001_v0_3_warmup_40.0_t_final_200.0_nu_0.001_n_samples_2000_ntraj_train_256_ntraj_eval_32_ntraj_test_32_drag_0.1_wave_number_4_random_seeds_combined_4.hdf5' # pylint: enable=line-too-long @@ -31,15 +31,16 @@ def get_config(): # Train params config.train_steps = 360_000 config.seed = 42 - config.lr = 5e-5 + config.lr = 5e-4 + config.use_lr_scheduler = True config.metric_aggregation_steps = 50 config.save_interval_steps = 36_000 config.max_checkpoints_to_keep = 10 - config.use_sobolev_norm 
= True
+  config.use_sobolev_norm = False
   config.order_sobolev_norm = 1
 
   # Data params
-  config.batch_size = 50
+  config.batch_size = 256
   # num_time_steps is length of ground truth trajectories in each batch from
   # dataloader (this differs from num_time_steps in Trainer.preprocess_batch
   # functions that corresponds to the len of ground truth trajectories to
@@ -53,47 +54,49 @@ def get_config():
   config.noise_level = 0.0
 
   # Model params
+  # TODO(yairschiff): Split CNN and FNO into separate configs
+  config.model = 'PeriodicConvNetModel'  # 'Fno' 'Fno2d'
   ########### PeriodicConvNetModel ################
-  config.model = 'PeriodicConvNetModel'
-  config.latent_dim = 96
-  config.num_levels = 4
-  config.num_processors = 4
-  config.encoder_kernel_size = (3, 3)
-  config.decoder_kernel_size = (3, 3)
-  config.processor_kernel_size = (3, 3)
+  config.latent_dim = 16
+  config.num_levels = 2
+  config.num_processors = 2
+  config.encoder_kernel_size = (5, 5)
+  config.decoder_kernel_size = (5, 5)
+  config.processor_kernel_size = (5, 5)
   config.padding = 'CIRCULAR'
   config.is_input_residual = True
-  # ########### FNO ################
-  # config.num_lookback_steps = 1
-  # config.model = 'Fno'
-  # config.out_channels = 1
-  # config.hidden_channels = 64
-  # config.num_modes = (20, 20)
-  # config.lifting_channels = 256
-  # config.projection_channels = 256
-  # config.num_blocks = 4
-  # config.layers_per_block = 2
-  # config.block_skip_type = 'identity'
-  # config.fft_norm = 'forward'
-  # config.separable = False
+  ############ FNO ################
+  config.out_channels = 1
+  config.hidden_channels = 64
+  config.num_modes = (20, 20)
+  config.lifting_channels = 256
+  config.projection_channels = 256
+  config.num_blocks = 4
+  config.layers_per_block = 2
+  config.block_skip_type = 'identity'
+  config.fft_norm = 'forward'
+  config.separable = False
   ########### FNO 2D ###############
-  # config.model = 'Fno2d'
-  # config.out_channels = 1
-  # config.num_modes = (20, 20)
-  # config.width = 128
-  # config.fft_norm = 'ortho'
+  config.out_channels = 1
+  config.num_modes = (20, 20)
+  config.width = 128
+  config.fft_norm = 'ortho'
 
   config.num_lookback_steps = 1
   # Update num_time_steps and integrator based on num_lookback_steps setting
-  if config.num_lookback_steps > 1:
+  config.num_time_steps += config.get_ref('num_lookback_steps') - 1
+  if config.get_ref('num_lookback_steps') > 1:
     config.integrator = 'MultiStepDirect'
   else:
     config.integrator = 'OneStepDirect'
   # Trainer params
+  config.rollout_weighting = 'geometric'
+  config.rollout_weighting_r = 0.9
+  config.rollout_weighting_clip = 10e-4
   config.num_rollout_steps = 1
-  config.train_steps_per_cycle = 0
+  config.train_steps_per_cycle = 72_000
   config.time_steps_increase_per_cycle = 1
   config.use_curriculum = False  # Sweepable
   config.use_pushfwd = False  # Sweepable
@@ -115,6 +119,8 @@ def skip(
 
   if not use_curriculum and use_pushfwd:
     return True
+  if measure_dist_lambda > 0.0 and measure_dist_k_lambda == 0.0:
+    return True
   if (
       measure_dist_type == 'SD'
       and measure_dist_lambda == 0.0
@@ -124,19 +130,20 @@ def skip(
   return False
 
 
+  # pylint: disable=line-too-long
 # TODO(yairschiff): Refactor sweeps and experiment definition to use gin. 
# use option --sweep=False in the command line to avoid sweeping def sweep(add): """Define param sweep.""" - for seed in [42]: - for normalize in [False, True]: - for measure_dist_type in ['MMD', 'SD']: - for batch_size in [50]: + for seed in [1, 11, 21]: + for normalize in [True]: + for measure_dist_type in ['MMD']: + for batch_size in [64, 128]: for lr in [5e-4]: - for use_curriculum in [False]: - for use_pushfwd in [False]: + for use_curriculum in [False, True]: + for use_pushfwd in [False, True]: for measure_dist_lambda in [0.0, 1.0]: - for measure_dist_k_lambda in [0.0, 1.0, 100.0]: + for measure_dist_k_lambda in [0.0, 1.0, 10.0, 100.0, 1000.0, 10000.0]: if use_curriculum: train_steps_per_cycle = 72_000 time_steps_increase_per_cycle = 1 @@ -164,3 +171,4 @@ def sweep(add): measure_dist_lambda=measure_dist_lambda, measure_dist_k_lambda=measure_dist_k_lambda, ) + # pylint: enable=line-too-long diff --git a/swirl_dynamics/projects/ergodic/generate_traj.py b/swirl_dynamics/projects/ergodic/generate_traj.py new file mode 100644 index 0000000..7116790 --- /dev/null +++ b/swirl_dynamics/projects/ergodic/generate_traj.py @@ -0,0 +1,200 @@ +# Copyright 2023 The swirl_dynamics Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generate predict trajectory. 
+ +Helper script for generating trajectories from pre-trained models +""" + +import functools +import json +from os import path as osp + +from absl import app +from absl import flags +from jax import numpy as jnp +import numpy as np +from orbax import checkpoint +import pandas as pd +from swirl_dynamics.data import utils as data_utils +from swirl_dynamics.lib.solvers import utils as solver_utils +from swirl_dynamics.projects.ergodic import choices +import tensorflow as tf +import tqdm.auto as tqdm + + +FLAGS = flags.FLAGS +flags.DEFINE_string("exp_dir", None, "Path to experiment with trained models.") + + +def create_train_name(args_dict): + """Create name from args.""" + model_display_dict = { + "PeriodicConvNetModel": "CNN", + "Fno": "FNO", + "Fno2d": "MNO", + } + use_curriculum = args_dict["use_curriculum"] + use_pushfwd = args_dict["use_pushfwd"] + measure_dist_type = args_dict["measure_dist_type"] + measure_dist_lambda = args_dict["measure_dist_lambda"] + measure_dist_k_lambda = args_dict["measure_dist_k_lambda"] + model_name = args_dict["model"] + if use_curriculum: + if use_pushfwd: + train_name = "Pfwd" + else: + train_name = "Curr" + else: + train_name = "1-step" + train_name_base = train_name + train_name += f" {model_display_dict[model_name]}" + if measure_dist_lambda > 0.0 or measure_dist_k_lambda > 0.0: + train_name += f" {measure_dist_type}" + train_name += f" λ1={int(measure_dist_lambda)}" + train_name += f", λ2={int(measure_dist_k_lambda)}" + return train_name_base, train_name + + +# Parse dirs +def parse_dir(exp_dir): + """Parse directory to load arguments json.""" + exps = {} + cnt = 0 + if tf.io.gfile.exists(exp_dir): + dirs = tf.io.gfile.listdir(exp_dir) + else: + raise FileNotFoundError(f"Could not list directory: {exp_dir}.") + for d in dirs: + if tf.io.gfile.exists(osp.join(exp_dir, d, "config.json")): + with tf.io.gfile.GFile(osp.join(exp_dir, d, "config.json"), "r") as f: + args = json.load(f) + if isinstance(args, str): + args = json.loads(args) + train_name_base, train_name = create_train_name(args) + args["ckpt_path"] = osp.join(exp_dir, d, "checkpoints") + args["traj_path"] = osp.join(exp_dir, d, "trajectories") + args["train_name_base"] = train_name_base + args["train_name"] = train_name + cnt += 1 + exps[cnt] = args + else: + continue + return pd.DataFrame.from_dict(exps, orient="index").sort_index() + + +def generate_pred_traj(exps_df, all_steps, dt, trajs, mean=None, std=None): + """Generate predicted trajectories and save to file.""" + pbar = tqdm.tqdm(exps_df.iterrows(), total=len(exps_df), desc="Exps") + # cnt = 0 + # skipped = 0 + for r in pbar: + first_step = r[1]["save_interval_steps"] + total_steps = r[1]["train_steps"] + save_every = r[1]["save_interval_steps"] + + train_name = r[1]["train_name"] + seed = r[1]["seed"] + ckpt_dir = r[1]["ckpt_path"] + integrator_choice = r[1]["integrator"] + model_choice = r[1]["model"] + normalize = r[1]["normalize"] + batch_size = r[1]["batch_size"] + ckpt_pbar = tqdm.tqdm( + range(total_steps, first_step - 1, -save_every), desc="Ckpts" + ) + cnt = 0 + skipped = 0 + for trained_steps in ckpt_pbar: + mngr = checkpoint.CheckpointManager( + ckpt_dir, checkpoint.PyTreeCheckpointer() + ) + print( + f"{train_name}; Bsz: {batch_size}, seed: {seed};" + f" {trained_steps:,d} steps", + end="; ", + ) + traj_dir = r[1]["traj_path"] + traj_file = osp.join(traj_dir, f"pred_traj_step={trained_steps}.hdf5") + + if tf.io.gfile.exists(traj_file): + print(f"File exists: {traj_file}.") + else: + if not 
tf.io.gfile.exists(osp.join(ckpt_dir, str(trained_steps))):
+          print("Skipping! Ckpt file does not exist.")
+          skipped += 1
+          cnt += 1
+          ckpt_pbar.set_postfix({"Skipped": f"{skipped} / {cnt}"})
+          continue
+        params = mngr.restore(step=trained_steps)["params"]
+        if not tf.io.gfile.exists(traj_dir):
+          tf.io.gfile.makedirs(traj_dir)
+        integrator = choices.Integrator(integrator_choice).dispatch()
+        model = choices.Model(model_choice).dispatch(r[1])
+        inference_model = functools.partial(
+            integrator,
+            solver_utils.nn_module_to_ode_dynamics(model),
+            params=dict(params=params),
+        )
+        ics = trajs[:, 0, ...]
+        if normalize:
+          ics = ics - mean  # Out-of-place: `ics` is a view into `trajs`, so
+          ics = ics / std   # in-place ops would corrupt later iterations.
+        pt = inference_model(ics, np.arange(all_steps) * dt)
+        if normalize:
+          pt *= std
+          pt += mean
+        del params
+        print("Generated.", end=" ")
+        data_utils.save_dict_to_hdf5(traj_file, {"pred_traj": pt})
+        print(f"Saved to file {traj_file}.")
+      cnt += 1
+
+
+def main(argv):
+  if len(argv) > 1:
+    raise app.UsageError("Too many command-line arguments.")
+  exp_dir = FLAGS.exp_dir
+  exps_df = parse_dir(exp_dir)
+  dataset_path = exps_df["dataset_path"].unique().tolist()[0]
+  trajs, tspan = data_utils.read_nparray_from_hdf5(
+      dataset_path,
+      "test/u",
+      "test/t",
+  )
+  all_steps = trajs.shape[1]
+  dt = jnp.mean(jnp.diff(tspan, axis=1))
+  print("Num traj:", trajs.shape[0])
+  print("traj length (steps):", all_steps)
+  print("dt:", dt)
+  spatial_downsample = exps_df["spatial_downsample_factor"].tolist()[0]
+  if trajs.ndim == 4:
+    trajs = trajs[:, :, ::spatial_downsample, :]
+  elif trajs.ndim == 5:
+    trajs = trajs[:, :, ::spatial_downsample, ::spatial_downsample, :]
+  print("Spatial resolution:", trajs.shape[2:-1])
+
+  train_snapshots = data_utils.read_nparray_from_hdf5(dataset_path, "train/u")[
+      0
+  ]
+  mean = jnp.mean(train_snapshots, axis=(0, 1))
+  std = jnp.std(train_snapshots, axis=(0, 1))
+  del train_snapshots
+  print("mean", mean[:10])
+  print("std", std[:10])
+  generate_pred_traj(exps_df, all_steps, dt, trajs=trajs, mean=mean, std=std)
+
+
+if __name__ == "__main__":
+  app.run(main)
diff --git a/swirl_dynamics/projects/ergodic/ks_1d.py b/swirl_dynamics/projects/ergodic/ks_1d.py
index e571141..de72f30 100644
--- a/swirl_dynamics/projects/ergodic/ks_1d.py
+++ b/swirl_dynamics/projects/ergodic/ks_1d.py
@@ -103,6 +103,11 @@ def plot_trajectories(
 class KS1DPlotFigures(stable_ar.PlotFigures):
   """Kuramoto Sivashinsky 1D plotting."""
 
+  def __init__(self, cos_sim_plot_steps: int = 500):
+    super().__init__()
+    # Correlation breaks down early, do not need all the steps
+    self.cos_sim_plot_steps = cos_sim_plot_steps
+
   def on_eval_batches_end(
       self, trainer: callbacks.Trainer, eval_metrics: Mapping[str, Array]
   ) -> None:
@@ -126,7 +131,7 @@ def on_eval_batches_end(
     figs.update(
         utils.plot_cos_sims(
             dt=dt,
-            traj_length=traj_length,
+            traj_length=min(traj_length, self.cos_sim_plot_steps),
            trajs=eval_metrics["all_trajs"]["trajs"],
            pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
        )
diff --git a/swirl_dynamics/projects/ergodic/lorenz63.py b/swirl_dynamics/projects/ergodic/lorenz63.py
index 36297d6..5dfc741 100644
--- a/swirl_dynamics/projects/ergodic/lorenz63.py
+++ b/swirl_dynamics/projects/ergodic/lorenz63.py
@@ -209,7 +209,7 @@ def plot_correlations(dt, traj_length, trajs, pred_trajs):
 class Lorenz63PlotFigures(stable_ar.PlotFigures):
   """Lorenz 63 plotting."""
 
-  def __init__(self, corr_plot_steps: int = 2000):
+  def __init__(self, corr_plot_steps: int = 20):
     super().__init__()
     # Correlation breaks down early, do not need all the steps
     self.corr_plot_steps = 
corr_plot_steps diff --git a/swirl_dynamics/projects/ergodic/main.py b/swirl_dynamics/projects/ergodic/main.py index 0bdf0e7..fc73a13 100644 --- a/swirl_dynamics/projects/ergodic/main.py +++ b/swirl_dynamics/projects/ergodic/main.py @@ -73,14 +73,13 @@ def main(argv): if experiment == choices.Experiment.L63: fig_callback_cls = lorenz63.Lorenz63PlotFigures state_dims = (3 // config.spatial_downsample_factor,) - optimizer = optax.adam(learning_rate=config.lr) elif experiment == choices.Experiment.KS_1D: fig_callback_cls = ks_1d.KS1DPlotFigures state_dims = ( 512 // config.spatial_downsample_factor, config.num_lookback_steps, ) - optimizer = optax.adam(learning_rate=config.lr) + elif experiment == choices.Experiment.NS_2D: fig_callback_cls = ns_2d.NS2dPlotFigures # TODO(yairschiff): This state dim is temporary for FNO data, should be 256 @@ -89,16 +88,20 @@ def main(argv): 64 // config.spatial_downsample_factor, config.num_lookback_steps, ) + else: + raise NotImplementedError(f"Unknown experiment: {config.experiment}") + + if config.use_lr_scheduler: optimizer = optax.adam( learning_rate=optax.exponential_decay( init_value=config.lr, - transition_steps=72_000, + transition_steps=config.train_steps_per_cycle, decay_rate=0.5, staircase=True, ) ) else: - raise NotImplementedError(f"Unknown experiment: {config.experiment}") + optimizer = optax.adam(learning_rate=config.lr) # Dataloaders if "use_tfds" in config and config.use_tfds: @@ -167,10 +170,14 @@ def main(argv): # Trainer trainer_config = stable_ar.StableARTrainerConfig( + rollout_weighting=choices.RolloutWeighting( + config.rollout_weighting + ).dispatch(config), num_rollout_steps=config.num_rollout_steps, num_lookback_steps=config.num_lookback_steps, add_noise=config.add_noise, use_curriculum=config.use_curriculum, + use_pushfwd=config.use_pushfwd, train_steps_per_cycle=config.train_steps_per_cycle, time_steps_increase_per_cycle=config.time_steps_increase_per_cycle, ) diff --git a/swirl_dynamics/projects/ergodic/measure_distances.py b/swirl_dynamics/projects/ergodic/measure_distances.py index 690e2db..466f51c 100644 --- a/swirl_dynamics/projects/ergodic/measure_distances.py +++ b/swirl_dynamics/projects/ergodic/measure_distances.py @@ -63,19 +63,23 @@ def mmd(x: Array, y: Array) -> Array: rx = jnp.broadcast_to(jnp.expand_dims(jnp.diag(xx), axis=0), xx.shape) ry = jnp.broadcast_to(jnp.expand_dims(jnp.diag(yy), axis=0), yy.shape) - dxx = rx.T + rx - 2.0 * xx # Used for A in (1) - dyy = ry.T + ry - 2.0 * yy # Used for B in (1) - dxy = rx.T + ry - 2.0 * zz # Used for C in (1) + dxx = rx.T + rx - 2.0 * xx + dyy = ry.T + ry - 2.0 * yy + dxy = rx.T + ry - 2.0 * zz xx, yy, xy = (jnp.zeros_like(xx), jnp.zeros_like(xx), jnp.zeros_like(xx)) # Multiscale + # TODO(yairschiff): We may need to experiment with these bandwidths to have + # MMD loss better distinguish distributions, especially for high dim data bandwidth_range = [0.2, 0.5, 0.9, 1.3] for a in bandwidth_range: xx += a**2 * (a**2 + dxx) ** -1 yy += a**2 * (a**2 + dyy) ** -1 xy += a**2 * (a**2 + dxy) ** -1 + # TODO(yairschiff): We may want to use jnp.sqrt(...) 
here; see: + # https://arxiv.org/abs/1502.02761 return jnp.mean(xx + yy - 2.0 * xy) diff --git a/swirl_dynamics/projects/ergodic/ns_2d.py b/swirl_dynamics/projects/ergodic/ns_2d.py index a9793e5..52c3482 100644 --- a/swirl_dynamics/projects/ergodic/ns_2d.py +++ b/swirl_dynamics/projects/ergodic/ns_2d.py @@ -26,7 +26,7 @@ def plot_trajectories( - dt, traj_lengths, trajs, pred_trajs, case_ids=(11, 32, 67, 89) + dt, traj_lengths, trajs, pred_trajs, case_ids=(1, 3, 5, 7) ): """Plot sample trajectories.""" assert trajs.shape[0] > max(case_ids), ( @@ -62,6 +62,11 @@ def plot_trajectories( class NS2dPlotFigures(stable_ar.PlotFigures): """Navier Stokes 2D plotting.""" + def __init__(self, cos_sim_plot_steps: int = 200): + super().__init__() + # Correlation breaks down early, do not need all the steps + self.cos_sim_plot_steps = cos_sim_plot_steps + def on_eval_batches_end( self, trainer: callbacks.Trainer, eval_metrics: Mapping[str, Array] ) -> None: @@ -84,7 +89,7 @@ def on_eval_batches_end( figs.update( utils.plot_cos_sims( dt=dt, - traj_length=traj_length, + traj_length=min(traj_length, self.cos_sim_plot_steps), trajs=eval_metrics["all_trajs"]["trajs"], pred_trajs=eval_metrics["all_trajs"]["pred_trajs"], ) diff --git a/swirl_dynamics/projects/ergodic/rollout_weighting.py b/swirl_dynamics/projects/ergodic/rollout_weighting.py new file mode 100644 index 0000000..a4d81b2 --- /dev/null +++ b/swirl_dynamics/projects/ergodic/rollout_weighting.py @@ -0,0 +1,93 @@ +# Copyright 2023 The swirl_dynamics Authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Different strategies for downweighting loss from rolled out steps. + +For each method, the input `num_time_steps` (int) corresponds to the number of +steps to be included in a batch, where the first time step corresponds to the +initial condition. Thus `num_time_steps = k + 1`, where `k` is number of rollout +steps. +""" + +import jax +from jax import numpy as jnp + +Array = jax.Array + + +def geometric( + num_time_steps: int, r: float = 0.1, clip: float = 10e-4 +) -> Array: + """Decay loss contribution as `loss * r^(k-1)`, where `r < 1`. + + Args: + num_time_steps: steps to be included in a batch, i.e., rollout steps + 1 + r: geometric weight. + clip: minimum weight. + + Returns: + Rollout weights array. + """ + assert r < 1, f"Geometric decay factor `r` ({r}) should be less than 1." + assert clip > 0, f"Minimum weight `clip` ({clip}) should be greater than 0." + return jnp.clip(r ** jnp.arange(0, num_time_steps - 1), a_min=clip) + + +def inverse_sqrt(num_time_steps: int, clip: float = 10e-4) -> Array: + """Decay loss contribution as `loss * 1 / sqrt(k)`. + + Args: + num_time_steps: steps to be included in a batch, i.e., rollout steps + 1 + clip: minimum weight. + + Returns: + Rollout weights array. + """ + assert clip > 0, f"Minimum weight `clip` ({clip}) should be greater than 0." 
+  return jnp.clip(jnp.arange(1, num_time_steps) ** -0.5, a_min=clip)
+
+
+def inverse_squared(num_time_steps: int, clip: float = 10e-4) -> Array:
+  """Decay loss contribution as `loss * 1 / (k^2)`.
+
+  Args:
+    num_time_steps: steps to be included in a batch, i.e., rollout steps + 1
+    clip: minimum weight.
+
+  Returns:
+    Rollout weights array.
+  """
+  assert clip > 0, f"Minimum weight `clip` ({clip}) should be greater than 0."
+  return jnp.clip(jnp.arange(1, num_time_steps) ** -2.0, a_min=clip)
+
+
+def linear(num_time_steps: int, m: float = 1.0, clip: float = 10e-4) -> Array:
+  """Decay loss contribution as `loss / (m * k)`.
+
+  Args:
+    num_time_steps: steps to be included in a batch, i.e., rollout steps + 1
+    m: slope of decay.
+    clip: minimum weight.
+
+  Returns:
+    Rollout weights array.
+  """
+  assert m > 0, f"Linear decay factor `m` ({m}) should be greater than 0."
+  assert clip > 0, f"Minimum weight `clip` ({clip}) should be greater than 0."
+  return jnp.clip((m * jnp.arange(1, num_time_steps)) ** -1.0, a_min=clip)
+
+
+def no_weight(num_time_steps: int) -> Array:
+  """No decay. All steps contribute equally."""
+  return jnp.ones(num_time_steps - 1)
diff --git a/swirl_dynamics/projects/ergodic/stable_ar.py b/swirl_dynamics/projects/ergodic/stable_ar.py
index 1583a8e..7a7efca 100644
--- a/swirl_dynamics/projects/ergodic/stable_ar.py
+++ b/swirl_dynamics/projects/ergodic/stable_ar.py
@@ -96,6 +96,7 @@ def loss_fn(
     # When using data parallelism, it will add an extra dimension due to the
     # pmap_reshape, so this line is to avoid shape mismatches.
     tspan = batch["tspan"].reshape((-1,))
+    rollout_weight = batch["rollout_weight"].reshape((-1,))
 
     # TODO(lzepedanunez): implement the logic in the Neural Markov paper.
     if self.conf.add_noise:
@@ -122,11 +123,16 @@ def loss_fn(
       )[:, -1, ...]
 
       # Computing losses.
-      measure_dist = self.conf.measure_dist(pred, true[:, 0, ...])
-      measure_dist_k = self.conf.measure_dist(pred, true[:, -1, ...])
+      measure_dist = (
+          self.conf.measure_dist(pred, true[:, 0, ...]) * rollout_weight[-1]
+      )
+      measure_dist_k = (
+          self.conf.measure_dist(pred, true[:, -1, ...]) * rollout_weight[-1]
+      )
 
       # Compare to true trajectory last step.
       if self.conf.use_sobolev_norm:
+        # TODO(yairschiff): Rollout weighting not implemented for this case!
         # The spatial dimension is the length of the shape minus 2,
         # which accounts for the batch, frame, and channel dimensions.
         dim = len(pred.shape) - 2
@@ -134,32 +140,45 @@ def loss_fn(
             pred - true[:, -1, ...], s=self.conf.order_sobolev_norm, dim=dim
         )
       else:
-        l2 = jnp.mean(jnp.square(pred - true[:, -1, ...]))
+        l2 = jnp.mean(
+            jnp.square(pred - true[:, -1, ...]).mean(
+                axis=tuple(range(1, pred.ndim))
+            )
+            * rollout_weight[-1]
+        )
     else:
       # Regular unrolling without stop-gradient
       # Expected shape: (bsz, num_rollout_steps, ...)
-      pred = self.pred_integrator(
-          x0, tspan, dict(params=params, **mutables)
-      )[:, self.conf.num_lookback_steps :, ...]
+      pred = self.pred_integrator(x0, tspan, dict(params=params, **mutables))[
+          :, self.conf.num_lookback_steps :, ...
+      ]
       measure_dist = jnp.mean(
           jax.vmap(
               lambda p: self.conf.measure_dist(p, true[:, 0, ...]),
               in_axes=(1),
-          )(pred)
+          )(pred) * rollout_weight
       )
       measure_dist_k = jnp.mean(
-          self.vmapped_measure_dist(pred, true[:, 1:, ...])
+          self.vmapped_measure_dist(pred, true[:, 1:, ...]) * rollout_weight
       )
 
       # Compare to full reference trajectory.
       # TODO(lzepedanunez): this is code is repeated.
       if self.conf.use_sobolev_norm:
+        # TODO(yairschiff): Rollout weighting not implemented for this case! 
dim = len(pred.shape) - 3 l2 = ergodic_utils.sobolev_norm( - pred - true[:, 1:, ...], s=self.conf.order_sobolev_norm, dim=dim + pred - true[:, 1:, ...], + s=self.conf.order_sobolev_norm, + dim=dim, ) else: - l2 = jnp.mean(jnp.square(pred - true[:, 1:, ...])) + l2 = jnp.mean( + jnp.square(pred - true[:, 1:, ...]).mean( + axis=tuple(range(2, pred.ndim)) + ) + * rollout_weight + ) # Gathering the metrics together. loss = l2 @@ -172,6 +191,7 @@ def loss_fn( measure_dist=measure_dist, measure_dist_k=measure_dist_k, rollout=jnp.array(tspan.shape[0] - 1), + max_rollout_decay=rollout_weight[-1], ) return loss, (metric, mutables) @@ -220,6 +240,7 @@ def eval_fn( class StableARTrainerConfig: """Config used by stable AR trainers.""" + rollout_weighting: choices.RolloutWeightingFn num_rollout_steps: int = 1 num_lookback_steps: int = 1 add_noise: bool = False @@ -243,6 +264,7 @@ class TrainMetrics(clu_metrics.Collection): measure_dist_k: clu_metrics.Average.from_output("measure_dist_k") measure_dist_k_std: clu_metrics.Average.from_output("measure_dist_k") rollout: clu_metrics.Average.from_output("rollout") + max_rollout_decay: clu_metrics.Average.from_output("max_rollout_decay") @flax.struct.dataclass class EvalMetrics(clu_metrics.Collection): @@ -275,6 +297,7 @@ def _preprocess_train_batch( dt = jnp.mean(jnp.diff(batch_data["t"], axis=1)) tspan = jnp.arange(num_time_steps) * dt + rollout_weight = self.conf.rollout_weighting(num_time_steps) # `x0`: first "state" (which can be `num_lookback_steps` time steps). # `true`: num_rollout_steps + 1 states (where the first state corresponds to # x0, except when num_lookback_steps > 1, where true[:, 0] corresponds to @@ -305,6 +328,7 @@ def _preprocess_train_batch( x0=x0, true=true, tspan=tspan, + rollout_weight=rollout_weight, ) def preprocess_train_batch( @@ -336,8 +360,8 @@ def preprocess_train_batch( )[0] # pytype: disable=attribute-error assert num_time_steps <= batch_data["u"].shape[1], ( - f"Not enough time steps in data ({batch_data.shape[1]}) for desired" - f" steps ({num_time_steps})." + f"Not enough time steps in data ({batch_data['u'].shape[1]}) for" + f" desired steps ({num_time_steps})." ) # pytype: enable=attribute-error return self._preprocess_train_batch(batch_data, num_time_steps) @@ -373,6 +397,7 @@ class TrainMetrics(clu_metrics.Collection): measure_dist_k: clu_metrics.Average.from_output("measure_dist_k") measure_dist_k_std: clu_metrics.Average.from_output("measure_dist_k") rollout: clu_metrics.Average.from_output("rollout") + max_rollout_decay: clu_metrics.Average.from_output("max_rollout_decay") @flax.struct.dataclass class EvalMetrics(clu_metrics.Collection): @@ -413,10 +438,7 @@ def _preprocess_train_batch( x0 = batch_data["u"][:, : self.conf.num_lookback_steps, ...] true = batch_data["u"][ :, - self.conf.num_lookback_steps - - 1 : num_time_steps - + self.conf.num_lookback_steps - - 1, # pylint: disable=line-too-long + self.conf.num_lookback_steps - 1 : num_time_steps + self.conf.num_lookback_steps - 1, # pylint: disable=line-too-long ..., ] else:
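Taken together, the `rollout_weight` plumbing above amounts to a small change in the rollout loss: per-step errors are scaled before averaging. Below is a minimal, self-contained sketch of the unrolled branch of `loss_fn`; the shapes are toy values and the layout `(batch, steps, space, channels)` is an assumption for illustration, not part of the patch:

```python
import jax.numpy as jnp

# Toy shapes: batch of 8, 4 rollout steps, 512 spatial points, 1 channel.
pred = jnp.ones((8, 4, 512, 1))         # model rollout for steps 1..4
true = jnp.zeros((8, 5, 512, 1))        # ground truth, index 0 = initial state
rollout_weight = 0.9 ** jnp.arange(4)   # e.g. geometric weights with r=0.9

# Per-step MSE over spatial/channel axes -> shape (bsz, num_rollout_steps);
# each step is scaled by its rollout weight before the final average.
per_step_mse = jnp.square(pred - true[:, 1:, ...]).mean(
    axis=tuple(range(2, pred.ndim))
)
l2 = jnp.mean(per_step_mse * rollout_weight)
```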