Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571432015
  • Loading branch information
The swirl_dynamics Authors committed Oct 6, 2023
1 parent a2ac7d6 commit 05dab19
Show file tree
Hide file tree
Showing 13 changed files with 597 additions and 209 deletions.
42 changes: 41 additions & 1 deletion swirl_dynamics/projects/ergodic/choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,10 +30,12 @@
from swirl_dynamics.lib.networks import nonlinear_fourier
from swirl_dynamics.lib.solvers import ode
from swirl_dynamics.projects.ergodic import measure_distances
from swirl_dynamics.projects.ergodic import rollout_weighting


Array = jax.Array
MeasureDistFn = Callable[[Array, Array], float | Array]
RolloutWeightingFn = Callable[[int], Array]


class Experiment(enum.Enum):
Expand Down Expand Up @@ -156,6 +158,44 @@ def dispatch(self, conf: ml_collections.ConfigDict) -> nn.Module:
out_channels=conf.out_channels,
num_modes=conf.num_modes,
width=conf.width,
fft_norm=conf.fft_norm
fft_norm=conf.fft_norm,
)
raise ValueError()


class RolloutWeighting(enum.Enum):
  """Rollout weighting choices.

  Selects how per-step losses are weighted across an autoregressive
  rollout. Each member dispatches to the corresponding function in
  `rollout_weighting`, with its hyperparameters bound from the config.
  """

  GEOMETRIC = "geometric"
  INV_SQRT = "inv_sqrt"
  INV_SQUARED = "inv_squared"
  LINEAR = "linear"
  NO_WEIGHT = "no_weight"

  def dispatch(self, conf: ml_collections.ConfigDict) -> RolloutWeightingFn:
    """Dispatch rollout weighting.

    Args:
      conf: Config dict; reads `rollout_weighting_clip` for all weighted
        schemes, plus `rollout_weighting_r` (geometric) or
        `rollout_weighting_m` (linear) as needed.

    Returns:
      A function mapping the number of rollout steps to an array of
      per-step weights.

    Raises:
      ValueError: If the enum member has no registered weighting scheme.
    """
    # Identity comparison is the idiomatic way to match enum members.
    if self is RolloutWeighting.GEOMETRIC:
      return functools.partial(
          rollout_weighting.geometric,
          r=conf.rollout_weighting_r,
          clip=conf.rollout_weighting_clip,
      )
    if self is RolloutWeighting.INV_SQRT:
      return functools.partial(
          rollout_weighting.inverse_sqrt,
          clip=conf.rollout_weighting_clip,
      )
    if self is RolloutWeighting.INV_SQUARED:
      return functools.partial(
          rollout_weighting.inverse_squared,
          clip=conf.rollout_weighting_clip,
      )
    if self is RolloutWeighting.LINEAR:
      return functools.partial(
          rollout_weighting.linear,
          m=conf.rollout_weighting_m,
          clip=conf.rollout_weighting_clip,
      )
    if self is RolloutWeighting.NO_WEIGHT:
      return rollout_weighting.no_weight
    # Was a bare ValueError(); include the offending value for debuggability.
    raise ValueError(f"Unknown rollout weighting: {self.value!r}")
18 changes: 10 additions & 8 deletions swirl_dynamics/projects/ergodic/colabs/demo.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -128,13 +128,13 @@
},
"outputs": [],
"source": [
"experiment = \"ks_1d\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n",
"batch_size = 128 #@param {type:\"integer\"}\n",
"experiment = \"ns_2d\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n",
"batch_size = 512 #@param {type:\"integer\"}\n",
"measure_dist_type = \"MMD\" #@param choices=['MMD', 'SD']\n",
"normalize = False #@param {type:\"boolean\"}\n",
"add_noise = False #@param {type:\"boolean\"}\n",
"use_curriculum = True #@param {type:\"boolean\"}\n",
"use_pushfwd = False #@param {type:\"boolean\"}\n",
"use_pushfwd = True #@param {type:\"boolean\"}\n",
"measure_dist_lambda = 0.0 #@param {type:\"number\"}\n",
"measure_dist_k_lambda = 0.0 #@param {type:\"number\"}\n",
"display_config = True #@param {type:\"boolean\"}\n",
Expand Down Expand Up @@ -183,7 +183,6 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"cellView": "form",
"id": "3oO4PJjqk-6i"
},
"outputs": [],
Expand Down Expand Up @@ -272,6 +271,9 @@
"\n",
"# Trainer\n",
"trainer_config = stable_ar.StableARTrainerConfig(\n",
" rollout_weighting=choices.RolloutWeighting(\n",
" config.rollout_weighting\n",
" ).dispatch(config),\n",
" num_rollout_steps=config.num_rollout_steps,\n",
" num_lookback_steps=config.num_lookback_steps,\n",
" add_noise=config.add_noise,\n",
Expand Down Expand Up @@ -311,7 +313,7 @@
" ),\n",
" callbacks.TqdmProgressBar(\n",
" total_train_steps=config.train_steps,\n",
" train_monitors=[\"rollout\", \"loss\", \"measure_dist\", \"measure_dist_k\"],\n",
" train_monitors=[\"rollout\", \"loss\", \"measure_dist\", \"measure_dist_k\", \"max_rollout_decay\"],\n",
" eval_monitors=[\"sd\"],\n",
" ),\n",
" fig_callback_cls()\n",
Expand All @@ -323,7 +325,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "cxxB4OZBuAA2"
"id": "-8lZXGG4wFks"
},
"outputs": [],
"source": [
Expand All @@ -333,7 +335,7 @@
"cell_type": "code",
"execution_count": null,
"metadata": {
"id": "2AT2YThCyNPt"
"id": "fu6IxOuFhULW"
},
"outputs": [],
"source": []
Expand All @@ -345,7 +347,7 @@
"d_9f5Fwifhd9"
],
"last_runtime": {
"build_target": "//learning/grp/tools/ml_python:ml_notebook",
"build_target": "//learning/deepmind/dm_python:dm_notebook3_tpu",
"kind": "private"
},
"private_outputs": true,
Expand Down
189 changes: 93 additions & 96 deletions swirl_dynamics/projects/ergodic/configs/ks_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,12 @@ def get_config():
config = ml_collections.ConfigDict()
config.experiment = 'ks_1d'
# Train params
config.train_steps = 50_000
config.train_steps = 300_000
config.seed = 42
config.lr = 1e-4
config.use_lr_scheduler = True
config.metric_aggregation_steps = 50
config.save_interval_steps = 5_000
config.save_interval_steps = 30_000
config.max_checkpoints_to_keep = 10
# Data params
config.batch_size = 128
Expand All @@ -50,20 +51,19 @@ def get_config():
config.order_sobolev_norm = 1
config.noise_level = 0.0

# TODO(yairschiff): Split different models into separate configs
# Model params
config.model = 'PeriodicConvNetModel' # 'Fno'
# TODO(yairschiff): Split CNN and FNO into separate configs
########### PeriodicConvNetModel ################
# config.model = 'PeriodicConvNetModel'
# config.latent_dim = 48
# config.num_levels = 4
# config.num_processors = 4
# config.encoder_kernel_size = (5,)
# config.decoder_kernel_size = (5,)
# config.processor_kernel_size = (5,)
# config.padding = 'CIRCULAR'
# config.is_input_residual = True
config.latent_dim = 48
config.num_levels = 4
config.num_processors = 4
config.encoder_kernel_size = (5,)
config.decoder_kernel_size = (5,)
config.processor_kernel_size = (5,)
config.padding = 'CIRCULAR'
config.is_input_residual = True
########### FNO ################
config.model = 'Fno'
config.out_channels = 1
config.hidden_channels = 64
config.num_modes = (24,)
Expand All @@ -77,14 +77,17 @@ def get_config():

config.num_lookback_steps = 1
# Update num_time_steps and integrator based on num_lookback_steps setting
config.num_time_steps += config.num_lookback_steps - 1
if config.num_lookback_steps > 1:
config.num_time_steps += config.get_ref('num_lookback_steps') - 1
if config.get_ref('num_lookback_steps') > 1:
config.integrator = 'MultiStepDirect'
else:
config.integrator = 'OneStepDirect'
# Trainer params
config.rollout_weighting = 'geometric'
config.rollout_weighting_r = 0.9
config.rollout_weighting_clip = 10e-4
config.num_rollout_steps = 1
config.train_steps_per_cycle = 5_000
config.train_steps_per_cycle = 60_000
config.time_steps_increase_per_cycle = 1
config.use_curriculum = False # Sweepable
config.use_pushfwd = False # Sweepable
Expand All @@ -106,6 +109,8 @@ def skip(

if not use_curriculum and use_pushfwd:
return True
if measure_dist_lambda > 0.0 and measure_dist_k_lambda == 0.0:
return True
if (
measure_dist_type == 'SD'
and measure_dist_lambda == 0.0
Expand All @@ -119,87 +124,79 @@ def skip(
# use option --sweep=False in the command line to avoid sweeping
def sweep(add):
  """Define param sweep."""
  # pylint: disable=line-too-long
  # These settings are held fixed across the whole sweep.
  normalize = True
  model = 'PeriodicConvNetModel'
  lr = 5e-4
  use_curriculum = True
  use_pushfwd = True
  # Curriculum is always enabled here, so cycle settings are constant.
  train_steps_per_cycle = 60_000
  time_steps_increase_per_cycle = 1
  seeds = [1, 11, 21, 31, 42]
  batch_sizes = [32, 64, 128, 256]
  dist_types = ['MMD', 'SD']
  dist_lambdas = [0.0, 1.0]
  dist_k_lambdas = [0.0, 1.0, 10.0, 100.0, 1000.0]
  for seed in seeds:
    for batch_size in batch_sizes:
      for measure_dist_type in dist_types:
        for measure_dist_lambda in dist_lambdas:
          for measure_dist_k_lambda in dist_k_lambdas:
            # Drop combinations that are redundant or unsupported.
            if skip(
                use_curriculum,
                use_pushfwd,
                measure_dist_lambda,
                measure_dist_k_lambda,
                measure_dist_type,
            ):
              continue
            add(
                seed=seed,
                normalize=normalize,
                model=model,
                batch_size=batch_size,
                lr=lr,
                measure_dist_type=measure_dist_type,
                train_steps_per_cycle=train_steps_per_cycle,
                time_steps_increase_per_cycle=time_steps_increase_per_cycle,
                use_curriculum=use_curriculum,
                use_pushfwd=use_pushfwd,
                measure_dist_lambda=measure_dist_lambda,
                measure_dist_k_lambda=measure_dist_k_lambda,
            )


# TODO(yairschiff): Ablation!
# def sweep(add):
# """Define param sweep."""
# for seed in [21, 42, 84]:
# for measure_dist_type in ['MMD', 'SD']:
# for batch_size in [32, 64, 128, 256]:
# for lr in [1e-3, 1e-4, 1e-5]:
# # Skipping 1-step objective
# for use_curriculum in [True]: # [False, True]:
# # Running grid search on just Pfwd
# for use_pushfwd in [True]: # [False, True]:
# # Skipping all x0 regs
# for regularize_measure_dist in [False]: # [False, True]:
# for regularize_measure_dist_k in [False, True]:
# for measure_dist_lambda in [0.0, 1.0, 100.0]:
# if use_curriculum:
# train_steps_per_cycle = 50_000
# time_steps_increase_per_cycle = 1
# else:
# train_steps_per_cycle = 0
# time_steps_increase_per_cycle = 0
# if skip(
# use_curriculum,
# use_pushfwd,
# regularize_measure_dist,
# regularize_measure_dist_k,
# measure_dist_lambda,
# measure_dist_type
# ):
# continue
# add(
# num_time_steps=61,
# train_steps=3_000_000,
# save_interval_steps=250_000,
# max_checkpoints_to_keep=3_000_000//250_000,
# seed=seed,
# batch_size=batch_size,
# lr=lr,
# measure_dist_type=measure_dist_type,
# train_steps_per_cycle=train_steps_per_cycle,
# time_steps_increase_per_cycle=time_steps_increase_per_cycle,
# use_curriculum=use_curriculum,
# use_pushfwd=use_pushfwd,
# regularize_measure_dist=regularize_measure_dist,
# regularize_measure_dist_k=regularize_measure_dist_k,
# measure_dist_lambda=measure_dist_lambda,
# )
# # pylint: disable=line-too-long
# for seed in [1, 11, 21, 31, 42]:
# for normalize in [True]:
# for model in ['PeriodicConvNetModel']:
# for batch_size in [32, 64, 128, 256, 512]:
# for lr in [5e-4]:
# for use_curriculum in [True]:
# for use_pushfwd in [True]:
# for measure_dist_type, measure_dist_lambda, measure_dist_k_lambda in [('MMD', 0.0, 10.0), ('SD', 0.0, 10.0)]: #[('MMD', 0.0, 0.0), ('MMD', 1.0, 1000.0), ('SD', 0.0, 1.0)]:
# train_steps_per_cycle = 60_000
# time_steps_increase_per_cycle = 1
# train_steps = 300_000
# max_checkpoints_to_keep = 10
# num_time_steps = 11
# add(
# seed=seed,
# train_steps=train_steps,
# max_checkpoints_to_keep=max_checkpoints_to_keep,
# num_time_steps=num_time_steps,
# normalize=normalize,
# model=model,
# batch_size=batch_size,
# lr=lr,
# measure_dist_type=measure_dist_type,
# train_steps_per_cycle=train_steps_per_cycle,
# time_steps_increase_per_cycle=time_steps_increase_per_cycle,
# use_curriculum=use_curriculum,
# use_pushfwd=use_pushfwd,
# measure_dist_lambda=measure_dist_lambda,
# measure_dist_k_lambda=measure_dist_k_lambda,
# )
# # pylint: enable=line-too-long
Loading

0 comments on commit 05dab19

Please sign in to comment.