Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 562971611
  • Loading branch information
The swirl_dynamics Authors committed Sep 6, 2023
1 parent 1fe9d36 commit a16969f
Show file tree
Hide file tree
Showing 8 changed files with 141 additions and 36 deletions.
23 changes: 6 additions & 17 deletions swirl_dynamics/projects/ergodic/colabs/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,15 +128,15 @@
},
"outputs": [],
"source": [
"experiment = \"lorenz63\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n",
"experiment = \"ks_1d\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n",
"batch_size = 50 #@param {type:\"integer\"}\n",
"measure_dist_type = \"MMD\" #@param choices=['MMD', 'SD']\n",
"normalize = True #@param {type:\"boolean\"}\n",
"normalize = False #@param {type:\"boolean\"}\n",
"add_noise = False #@param {type:\"boolean\"}\n",
"use_curriculum = False #@param {type:\"boolean\"}\n",
"use_pushfwd = False #@param {type:\"boolean\"}\n",
"measure_dist_lambda = 1.0 #@param {type:\"number\"}\n",
"measure_dist_k_lambda = 10.0 #@param {type:\"number\"}\n",
"measure_dist_lambda = 0.0 #@param {type:\"number\"}\n",
"measure_dist_k_lambda = 0.0 #@param {type:\"number\"}\n",
"display_config = True #@param {type:\"boolean\"}\n",
"config = get_config(\n",
" experiment,\n",
Expand Down Expand Up @@ -179,17 +179,6 @@
"workdir = \"\u003cTODO: INSERT WORKDIR HERE\u003e\" #@param\n",
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "GxHm2IzCqVgq"
},
"outputs": [],
"source": [
"config.save_interval_steps = 20"
]
},
{
"cell_type": "code",
"execution_count": null,
Expand Down Expand Up @@ -334,7 +323,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "0Um9nKZhcf6K"
"id": "CCBlVwYaKrhy"
},
"outputs": [],
"source": []
Expand All @@ -346,7 +335,7 @@
"d_9f5Fwifhd9"
],
"last_runtime": {
"build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"kind": "private"
},
"private_outputs": true,
Expand Down
12 changes: 6 additions & 6 deletions swirl_dynamics/projects/ergodic/configs/ks_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@ def get_config():
config.max_checkpoints_to_keep = 10
# Data params
config.batch_size = 128
config.num_time_steps = 11
config.num_time_steps = 2
config.time_stride = 1
config.dataset_path = DATA_PATH
config.spatial_downsample_factor = 1
Expand All @@ -61,22 +61,22 @@ def get_config():
config.model = 'FNO'
config.out_channels = 1
config.hidden_channels = 64
config.num_modes = (512,)
config.num_modes = (24,)
config.lifting_channels = 256
config.projection_channels = 256
config.num_blocks = 4
config.layers_per_block = 2
config.block_skip_type = 'soft-gate'
config.block_skip_type = 'identity'
config.fft_norm = 'forward'
config.separable = False
# Update num_time_steps based on num_lookback_steps setting
config.num_time_steps += config.num_lookback_steps - 1
# Trainer params
config.num_rollout_steps = 1
config.train_steps_per_cycle = 50_000
config.time_steps_increase_per_cycle = 1
config.use_curriculum = True # Sweepable
config.use_pushfwd = True # Sweepable
config.time_steps_increase_per_cycle = 0
config.use_curriculum = False # Sweepable
config.use_pushfwd = False # Sweepable
config.measure_dist_type = 'MMD' # Sweepable
config.measure_dist_downsample = 1
config.measure_dist_lambda = 0.0 # Sweepable
Expand Down
18 changes: 9 additions & 9 deletions swirl_dynamics/projects/ergodic/configs/lorenz63.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import ml_collections

# pylint: disable=line-too-long
DATA_PATH = '/datasets/gcs_staging/hdf5/ode/lorenz63_trajectories.hdf5'
DATA_PATH = '/datasets/gcs_staging/hdf5/ode/1d/lorenz63_trajectories.hdf5'
# pylint: enable=line-too-long


Expand Down Expand Up @@ -53,13 +53,13 @@ def get_config():
# Trainer params
config.num_rollout_steps = 1
config.train_steps_per_cycle = 50_000
config.time_steps_increase_per_cycle = 1
config.use_curriculum = True # Sweepable
config.use_pushfwd = True # Sweepable
config.measure_dist_type = 'SD' # Sweepable
config.time_steps_increase_per_cycle = 0
config.use_curriculum = False # Sweepable
config.use_pushfwd = False # Sweepable
config.measure_dist_type = 'MMD' # Sweepable
config.measure_dist_downsample = 1
config.measure_dist_lambda = 1.0 # Sweepable
config.measure_dist_k_lambda = 1000.0 # Sweepable
config.measure_dist_lambda = 0.0 # Sweepable
config.measure_dist_k_lambda = 0.0 # Sweepable
return config


Expand All @@ -86,9 +86,9 @@ def skip(
# TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
def sweep(add):
"""Define param sweep."""
for seed in [1, 11, 21, 42, 84]:
for seed in [42]:
for measure_dist_type in ['MMD', 'SD']:
for batch_size in [4096]:
for batch_size in [2048]:
for use_curriculum in [False, True]:
for use_pushfwd in [False, True]:
for measure_dist_lambda in [0.0, 1.0, 10.0, 100.0, 1000.0]:
Expand Down
8 changes: 4 additions & 4 deletions swirl_dynamics/projects/ergodic/configs/ns_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ def get_config():
config.time_stride = 1
config.dataset_path = DATA_PATH
config.spatial_downsample_factor = 1
config.normalize = True
config.normalize = False
config.add_noise = False
config.noise_level = 0.0
# Model params
Expand All @@ -67,7 +67,7 @@ def get_config():
config.projection_channels = 256
config.num_blocks = 4
config.layers_per_block = 2
config.block_skip_type = 'soft-gate'
config.block_skip_type = 'identity'
config.fft_norm = 'forward'
config.separable = False
# Update num_time_steps based on num_lookback_steps setting
Expand All @@ -90,8 +90,8 @@ def sweep(add):
"""Define param sweep."""
for seed in [42]:
for measure_dist_type in ['MMD', 'SD']:
for measure_dist_k_lambda in [0.0, 1.0, 10.0]:
for measure_dist_lambda in [0.0, 1.0]:
for measure_dist_k_lambda in [100.0, 1000.0]:
for measure_dist_lambda in [0.0]:
for measure_dist_downsample in [1, 2]:
if measure_dist_k_lambda == measure_dist_lambda == 0.0:
if measure_dist_type == 'SD' or measure_dist_downsample > 1:
Expand Down
8 changes: 8 additions & 0 deletions swirl_dynamics/projects/ergodic/ks_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,14 @@ def on_eval_batches_end(
pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
)
)
figs.update(
utils.plot_cos_sims(
dt=dt,
traj_length=traj_length,
trajs=eval_metrics["all_trajs"]["trajs"],
pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
)
)
figs_as_images = {
k: self.figure_to_image(v).transpose(1, 2, 0) for k, v in figs.items()
}
Expand Down
41 changes: 41 additions & 0 deletions swirl_dynamics/projects/ergodic/lorenz63.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,38 @@ def plot_trajectory_hists(dt, traj_lengths, trajs, pred_trajs):
return {"traj_hists": fig}


def plot_correlations(dt, traj_length, trajs, pred_trajs):
"""Plots coordinate-wise correlations over time."""
fig, ax = plt.subplots(
nrows=1, ncols=3, sharey=True, figsize=(17, 5), tight_layout=True
)
ax[0].set_ylabel("Corr. Coeff.")
fig.suptitle("Correlation: ground truth and pred. trajectories across time")
for d, n in zip(range(3), ["x", "y", "z"]):
ax[d].plot(
jnp.arange(traj_length) * dt,
jnp.ones(traj_length)*0.9,
color="black", linestyle="dashed",
label="0.9 threshold"
)
ax[d].plot(
jnp.arange(traj_length) * dt,
jnp.ones(traj_length)*0.8,
color="red", linestyle="dashed",
label="0.8 threshold"
)
ax[d].set_xlim(0, traj_length*dt)
ax[d].set_xlabel("t")
ax[d].set_title(n)
for d in range(3):
corrs = jax.vmap(jnp.corrcoef, in_axes=(1, 1))(
trajs[:, :traj_length, d], pred_trajs[:, :traj_length, d]
)[:, 1, 0]
ax[d].plot(jnp.arange(traj_length) * dt, corrs)
ax[-1].legend(frameon=False, bbox_to_anchor=(1, 1))
return {"corr": fig}


class Lorenz63PlotFigures(stable_ar.PlotFigures):
"""Lorenz 63 plotting."""

Expand All @@ -196,6 +228,15 @@ def on_eval_batches_end(
pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
)
)
figs.update(
plot_correlations(
dt=dt,
# Correlation breaks down early, do not need all the steps
traj_length=min(traj_length, 20),
trajs=eval_metrics["all_trajs"]["trajs"],
pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
)
)
figs_as_images = {
k: self.figure_to_image(v).transpose(1, 2, 0) for k, v in figs.items()
}
Expand Down
8 changes: 8 additions & 0 deletions swirl_dynamics/projects/ergodic/ns_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,14 @@ def on_eval_batches_end(
pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
)
)
figs.update(
utils.plot_cos_sims(
dt=dt,
traj_length=traj_length,
trajs=eval_metrics["all_trajs"]["trajs"],
pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
)
)
figs_as_images = {
k: self.figure_to_image(v).transpose(1, 2, 0) for k, v in figs.items()
}
Expand Down
59 changes: 59 additions & 0 deletions swirl_dynamics/projects/ergodic/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,3 +320,62 @@ def linear_scale_dissipative_target(inputs: Array, scale: float = 1.0):
Operator (MNO).
"""
return scale * inputs


def plot_cos_sims(dt: Array, traj_length: int, trajs: Array, pred_trajs: Array):
"""Plot cosine similarities over time."""
def sum_non_batch_dims(x: Array) -> Array:
"""Helper method to sum array along all dimensions except the 0th."""
ndim = x.ndim
return x.sum(axis=tuple(range(1, ndim)))

def state_cos_sim(x: Array, y: Array) -> Array:
"""Compute cosine similiarity between two batches of states.
Computes x^Ty / ||x||*||y|| averaged across batch dimension (axis = 0).
Args:
x: array of states; shape: batch_size x state_dimension
y: array of states; shape: batch_size x state_dimension
Returns:
cosine similarity averaged along batch dimension.
"""
x_norm = jnp.expand_dims(
jnp.sqrt(sum_non_batch_dims((x ** 2))),
axis=tuple(range(1, x.ndim))
)
x /= x_norm
y_norm = jnp.expand_dims(
jnp.sqrt(sum_non_batch_dims((y ** 2))),
axis=tuple(range(1, y.ndim))
)
y /= y_norm
return sum_non_batch_dims(x * y).mean(axis=0)

plot_time = jnp.arange(traj_length) * dt
t_max = plot_time.max()
fig, ax = plt.subplots(1, 1, figsize=(7, 4), tight_layout=True)
# Plot 0.9, 0.8 threshold lines
ax.plot(
plot_time,
jnp.ones(traj_length)*0.9,
color="black", linestyle="dashed",
label="0.9 threshold"
)
ax.plot(
plot_time,
jnp.ones(traj_length)*0.8,
color="red", linestyle="dashed",
label="0.8 threshold"
)
# Plot correlation lines
cosine_sims = jax.vmap(
state_cos_sim, in_axes=(1, 1)
)(trajs[:, :traj_length, :], pred_trajs[:, :traj_length, :])
ax.plot(plot_time, cosine_sims)
ax.set_xlim([0, t_max])
ax.set_xlabel(r"$t$")
ax.set_ylabel("Avg. cosine sim.")
ax.set_title("Cosine Similiarity over time")
ax.legend(frameon=False, bbox_to_anchor=(1, 1))
return {"cosine_sim": fig}

0 comments on commit a16969f

Please sign in to comment.