diff --git a/swirl_dynamics/projects/ergodic/colabs/demo.ipynb b/swirl_dynamics/projects/ergodic/colabs/demo.ipynb index b18abf0..bf9717d 100644 --- a/swirl_dynamics/projects/ergodic/colabs/demo.ipynb +++ b/swirl_dynamics/projects/ergodic/colabs/demo.ipynb @@ -128,15 +128,15 @@ }, "outputs": [], "source": [ - "experiment = \"lorenz63\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n", + "experiment = \"ks_1d\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n", "batch_size = 50 #@param {type:\"integer\"}\n", "measure_dist_type = \"MMD\" #@param choices=['MMD', 'SD']\n", - "normalize = True #@param {type:\"boolean\"}\n", + "normalize = False #@param {type:\"boolean\"}\n", "add_noise = False #@param {type:\"boolean\"}\n", "use_curriculum = False #@param {type:\"boolean\"}\n", "use_pushfwd = False #@param {type:\"boolean\"}\n", - "measure_dist_lambda = 1.0 #@param {type:\"number\"}\n", - "measure_dist_k_lambda = 10.0 #@param {type:\"number\"}\n", + "measure_dist_lambda = 0.0 #@param {type:\"number\"}\n", + "measure_dist_k_lambda = 0.0 #@param {type:\"number\"}\n", "display_config = True #@param {type:\"boolean\"}\n", "config = get_config(\n", " experiment,\n", @@ -179,17 +179,6 @@ "workdir = \"\u003cTODO: INSERT WORKDIR HERE\u003e\" #@param\n", ] }, - { - "cell_type": "code", - "execution_count": null, - "metadata": { - "id": "GxHm2IzCqVgq" - }, - "outputs": [], - "source": [ - "config.save_interval_steps = 20" - ] - }, { "cell_type": "code", "execution_count": null, @@ -334,7 +323,7 @@ "cell_type": "code", "execution_count": null, "metadata": { - "id": "0Um9nKZhcf6K" + "id": "CCBlVwYaKrhy" }, "outputs": [], "source": [] @@ -346,7 +335,7 @@ "d_9f5Fwifhd9" ], "last_runtime": { - "build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu", + "build_target": "//learning/grp/tools/ml_python:ml_notebook", "kind": "private" }, "private_outputs": true, diff --git a/swirl_dynamics/projects/ergodic/configs/ks_1d.py b/swirl_dynamics/projects/ergodic/configs/ks_1d.py index 68f34f7..fe04a6d 100644 --- a/swirl_dynamics/projects/ergodic/configs/ks_1d.py +++ b/swirl_dynamics/projects/ergodic/configs/ks_1d.py @@ -36,7 +36,7 @@ def get_config(): config.max_checkpoints_to_keep = 10 # Data params config.batch_size = 128 - config.num_time_steps = 11 + config.num_time_steps = 2 config.time_stride = 1 config.dataset_path = DATA_PATH config.spatial_downsample_factor = 1 @@ -61,12 +61,12 @@ def get_config(): config.model = 'FNO' config.out_channels = 1 config.hidden_channels = 64 - config.num_modes = (512,) + config.num_modes = (24,) config.lifting_channels = 256 config.projection_channels = 256 config.num_blocks = 4 config.layers_per_block = 2 - config.block_skip_type = 'soft-gate' + config.block_skip_type = 'identity' config.fft_norm = 'forward' config.separable = False # Update num_time_steps based on num_lookback_steps setting @@ -74,9 +74,9 @@ def get_config(): # Trainer params config.num_rollout_steps = 1 config.train_steps_per_cycle = 50_000 - config.time_steps_increase_per_cycle = 1 - config.use_curriculum = True # Sweepable - config.use_pushfwd = True # Sweepable + config.time_steps_increase_per_cycle = 0 + config.use_curriculum = False # Sweepable + config.use_pushfwd = False # Sweepable config.measure_dist_type = 'MMD' # Sweepable config.measure_dist_downsample = 1 config.measure_dist_lambda = 0.0 # Sweepable diff --git a/swirl_dynamics/projects/ergodic/configs/lorenz63.py b/swirl_dynamics/projects/ergodic/configs/lorenz63.py index 1d333ed..b9a0a91 100644 --- a/swirl_dynamics/projects/ergodic/configs/lorenz63.py +++ b/swirl_dynamics/projects/ergodic/configs/lorenz63.py @@ -19,7 +19,7 @@ import ml_collections # pylint: disable=line-too-long -DATA_PATH = '/datasets/gcs_staging/hdf5/ode/lorenz63_trajectories.hdf5' +DATA_PATH = '/datasets/gcs_staging/hdf5/ode/1d/lorenz63_trajectories.hdf5' # pylint: enable=line-too-long @@ -53,13 +53,13 @@ def get_config(): # Trainer params config.num_rollout_steps = 1 config.train_steps_per_cycle = 50_000 - config.time_steps_increase_per_cycle = 1 - config.use_curriculum = True # Sweepable - config.use_pushfwd = True # Sweepable - config.measure_dist_type = 'SD' # Sweepable + config.time_steps_increase_per_cycle = 0 + config.use_curriculum = False # Sweepable + config.use_pushfwd = False # Sweepable + config.measure_dist_type = 'MMD' # Sweepable config.measure_dist_downsample = 1 - config.measure_dist_lambda = 1.0 # Sweepable - config.measure_dist_k_lambda = 1000.0 # Sweepable + config.measure_dist_lambda = 0.0 # Sweepable + config.measure_dist_k_lambda = 0.0 # Sweepable return config @@ -86,9 +86,9 @@ def skip( # TODO(yairschiff): Refactor sweeps and experiment definition to use gin. def sweep(add): """Define param sweep.""" - for seed in [1, 11, 21, 42, 84]: + for seed in [42]: for measure_dist_type in ['MMD', 'SD']: - for batch_size in [4096]: + for batch_size in [2048]: for use_curriculum in [False, True]: for use_pushfwd in [False, True]: for measure_dist_lambda in [0.0, 1.0, 10.0, 100.0, 1000.0]: diff --git a/swirl_dynamics/projects/ergodic/configs/ns_2d.py b/swirl_dynamics/projects/ergodic/configs/ns_2d.py index 64688ba..b2eae2c 100644 --- a/swirl_dynamics/projects/ergodic/configs/ns_2d.py +++ b/swirl_dynamics/projects/ergodic/configs/ns_2d.py @@ -41,7 +41,7 @@ def get_config(): config.time_stride = 1 config.dataset_path = DATA_PATH config.spatial_downsample_factor = 1 - config.normalize = True + config.normalize = False config.add_noise = False config.noise_level = 0.0 # Model params @@ -67,7 +67,7 @@ def get_config(): config.projection_channels = 256 config.num_blocks = 4 config.layers_per_block = 2 - config.block_skip_type = 'soft-gate' + config.block_skip_type = 'identity' config.fft_norm = 'forward' config.separable = False # Update num_time_steps based on num_lookback_steps setting @@ -90,8 +90,8 @@ def sweep(add): """Define param sweep.""" for seed in [42]: for measure_dist_type in ['MMD', 'SD']: - for measure_dist_k_lambda in [0.0, 1.0, 10.0]: - for measure_dist_lambda in [0.0, 1.0]: + for measure_dist_k_lambda in [100.0, 1000.0]: + for measure_dist_lambda in [0.0]: for measure_dist_downsample in [1, 2]: if measure_dist_k_lambda == measure_dist_lambda == 0.0: if measure_dist_type == 'SD' or measure_dist_downsample > 1: diff --git a/swirl_dynamics/projects/ergodic/ks_1d.py b/swirl_dynamics/projects/ergodic/ks_1d.py index 506c0be..989db20 100644 --- a/swirl_dynamics/projects/ergodic/ks_1d.py +++ b/swirl_dynamics/projects/ergodic/ks_1d.py @@ -115,6 +115,14 @@ def on_eval_batches_end( pred_trajs=eval_metrics["all_trajs"]["pred_trajs"], ) ) + figs.update( + utils.plot_cos_sims( + dt=dt, + traj_length=traj_length, + trajs=eval_metrics["all_trajs"]["trajs"], + pred_trajs=eval_metrics["all_trajs"]["pred_trajs"], + ) + ) figs_as_images = { k: self.figure_to_image(v).transpose(1, 2, 0) for k, v in figs.items() } diff --git a/swirl_dynamics/projects/ergodic/lorenz63.py b/swirl_dynamics/projects/ergodic/lorenz63.py index 457f450..a3d4cd8 100644 --- a/swirl_dynamics/projects/ergodic/lorenz63.py +++ b/swirl_dynamics/projects/ergodic/lorenz63.py @@ -174,6 +174,38 @@ def plot_trajectory_hists(dt, traj_lengths, trajs, pred_trajs): return {"traj_hists": fig} +def plot_correlations(dt, traj_length, trajs, pred_trajs): + """Plots coordinate-wise correlations over time.""" + fig, ax = plt.subplots( + nrows=1, ncols=3, sharey=True, figsize=(17, 5), tight_layout=True + ) + ax[0].set_ylabel("Corr. Coeff.") + fig.suptitle("Correlation: ground truth and pred. trajectories across time") + for d, n in zip(range(3), ["x", "y", "z"]): + ax[d].plot( + jnp.arange(traj_length) * dt, + jnp.ones(traj_length)*0.9, + color="black", linestyle="dashed", + label="0.9 threshold" + ) + ax[d].plot( + jnp.arange(traj_length) * dt, + jnp.ones(traj_length)*0.8, + color="red", linestyle="dashed", + label="0.8 threshold" + ) + ax[d].set_xlim(0, traj_length*dt) + ax[d].set_xlabel("t") + ax[d].set_title(n) + for d in range(3): + corrs = jax.vmap(jnp.corrcoef, in_axes=(1, 1))( + trajs[:, :traj_length, d], pred_trajs[:, :traj_length, d] + )[:, 1, 0] + ax[d].plot(jnp.arange(traj_length) * dt, corrs) + ax[-1].legend(frameon=False, bbox_to_anchor=(1, 1)) + return {"corr": fig} + + class Lorenz63PlotFigures(stable_ar.PlotFigures): """Lorenz 63 plotting.""" @@ -196,6 +228,15 @@ def on_eval_batches_end( pred_trajs=eval_metrics["all_trajs"]["pred_trajs"], ) ) + figs.update( + plot_correlations( + dt=dt, + # Correlation breaks down early, do not need all the steps + traj_length=min(traj_length, 20), + trajs=eval_metrics["all_trajs"]["trajs"], + pred_trajs=eval_metrics["all_trajs"]["pred_trajs"], + ) + ) figs_as_images = { k: self.figure_to_image(v).transpose(1, 2, 0) for k, v in figs.items() } diff --git a/swirl_dynamics/projects/ergodic/ns_2d.py b/swirl_dynamics/projects/ergodic/ns_2d.py index 00918cb..c1bbf19 100644 --- a/swirl_dynamics/projects/ergodic/ns_2d.py +++ b/swirl_dynamics/projects/ergodic/ns_2d.py @@ -74,6 +74,14 @@ def on_eval_batches_end( pred_trajs=eval_metrics["all_trajs"]["pred_trajs"], ) ) + figs.update( + utils.plot_cos_sims( + dt=dt, + traj_length=traj_length, + trajs=eval_metrics["all_trajs"]["trajs"], + pred_trajs=eval_metrics["all_trajs"]["pred_trajs"], + ) + ) figs_as_images = { k: self.figure_to_image(v).transpose(1, 2, 0) for k, v in figs.items() } diff --git a/swirl_dynamics/projects/ergodic/utils.py b/swirl_dynamics/projects/ergodic/utils.py index f6b2a01..bf79388 100644 --- a/swirl_dynamics/projects/ergodic/utils.py +++ b/swirl_dynamics/projects/ergodic/utils.py @@ -320,3 +320,62 @@ def linear_scale_dissipative_target(inputs: Array, scale: float = 1.0): Operator (MNO). """ return scale * inputs + + +def plot_cos_sims(dt: Array, traj_length: int, trajs: Array, pred_trajs: Array): + """Plot cosine similarities over time.""" + def sum_non_batch_dims(x: Array) -> Array: + """Helper method to sum array along all dimensions except the 0th.""" + ndim = x.ndim + return x.sum(axis=tuple(range(1, ndim))) + + def state_cos_sim(x: Array, y: Array) -> Array: + """Compute cosine similiarity between two batches of states. + + Computes x^Ty / ||x||*||y|| averaged across batch dimension (axis = 0). + + Args: + x: array of states; shape: batch_size x state_dimension + y: array of states; shape: batch_size x state_dimension + Returns: + cosine similarity averaged along batch dimension. + """ + x_norm = jnp.expand_dims( + jnp.sqrt(sum_non_batch_dims((x ** 2))), + axis=tuple(range(1, x.ndim)) + ) + x /= x_norm + y_norm = jnp.expand_dims( + jnp.sqrt(sum_non_batch_dims((y ** 2))), + axis=tuple(range(1, y.ndim)) + ) + y /= y_norm + return sum_non_batch_dims(x * y).mean(axis=0) + + plot_time = jnp.arange(traj_length) * dt + t_max = plot_time.max() + fig, ax = plt.subplots(1, 1, figsize=(7, 4), tight_layout=True) + # Plot 0.9, 0.8 threshold lines + ax.plot( + plot_time, + jnp.ones(traj_length)*0.9, + color="black", linestyle="dashed", + label="0.9 threshold" + ) + ax.plot( + plot_time, + jnp.ones(traj_length)*0.8, + color="red", linestyle="dashed", + label="0.8 threshold" + ) + # Plot correlation lines + cosine_sims = jax.vmap( + state_cos_sim, in_axes=(1, 1) + )(trajs[:, :traj_length, :], pred_trajs[:, :traj_length, :]) + ax.plot(plot_time, cosine_sims) + ax.set_xlim([0, t_max]) + ax.set_xlabel(r"$t$") + ax.set_ylabel("Avg. cosine sim.") + ax.set_title("Cosine Similiarity over time") + ax.legend(frameon=False, bbox_to_anchor=(1, 1)) + return {"cosine_sim": fig}