Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565182278
  • Loading branch information
The swirl_dynamics Authors committed Sep 13, 2023
1 parent 587f7e2 commit 2b5cd84
Show file tree
Hide file tree
Showing 7 changed files with 22 additions and 11 deletions.
15 changes: 8 additions & 7 deletions swirl_dynamics/projects/ergodic/configs/ks_1d.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,21 +28,22 @@ def get_config():
config = ml_collections.ConfigDict()
config.experiment = 'ks_1d'
# Train params
config.train_steps = 500_000
config.train_steps = 50_000
config.seed = 42
config.lr = 1e-4
config.metric_aggregation_steps = 50
config.save_interval_steps = 50_000
config.save_interval_steps = 5_000
config.max_checkpoints_to_keep = 10
# Data params
config.batch_size = 128
config.batch_size = 32
config.num_time_steps = 2
config.time_stride = 1
config.dataset_path = DATA_PATH
config.spatial_downsample_factor = 1
config.normalize = False
config.add_noise = False
config.sobolev_norm = False
config.use_sobolev_norm = False
config.order_sobolev_norm = 1
config.noise_level = 0.0

# Model params
Expand Down Expand Up @@ -78,14 +79,14 @@ def get_config():
config.num_time_steps += config.num_lookback_steps - 1
# Trainer params
config.num_rollout_steps = 1
config.train_steps_per_cycle = 50_000
config.time_steps_increase_per_cycle = 0
config.train_steps_per_cycle = 5_000
config.time_steps_increase_per_cycle = 1
config.use_curriculum = False # Sweepable
config.use_pushfwd = False # Sweepable
config.measure_dist_type = 'MMD' # Sweepable
config.measure_dist_downsample = 1
config.measure_dist_lambda = 0.0 # Sweepable
config.measure_dist_k_lambda = 0.0 # Sweepable
config.measure_dist_k_lambda = 1.0 # Sweepable
return config


Expand Down
3 changes: 2 additions & 1 deletion swirl_dynamics/projects/ergodic/configs/ks_1d_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def get_config():
config.spatial_downsample_factor = 1
config.normalize = False
config.add_noise = False
config.sobolev_norm = False
config.use_sobolev_norm = False
config.order_sobolev_norm = 0
config.noise_level = 0.0
config.num_time_steps_eval = 600
config.batch_size_eval = 512
Expand Down
3 changes: 3 additions & 0 deletions swirl_dynamics/projects/ergodic/configs/lorenz63.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def get_config():
config.metric_aggregation_steps = 50
config.save_interval_steps = 50_000
config.max_checkpoints_to_keep = 10
config.use_sobolev_norm = False
config.order_sobolev_norm = 0

# Data params
config.batch_size = 4096
config.num_time_steps = 11
Expand Down
3 changes: 3 additions & 0 deletions swirl_dynamics/projects/ergodic/configs/ns_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,9 @@ def get_config():
config.metric_aggregation_steps = 50
config.save_interval_steps = 50_000
config.max_checkpoints_to_keep = 10
config.use_sobolev_norm = True
config.order_sobolev_norm = 1

# Data params
config.batch_size = 32
config.num_time_steps = 2
Expand Down
3 changes: 2 additions & 1 deletion swirl_dynamics/projects/ergodic/configs/ns_2d_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def get_config():
config.normalize = False
config.add_noise = False
config.noise_level = 0.0
config.sobolev_norm = False
config.use_sobolev_norm = False
config.order_sobolev_norm = 0

# Model params
config.num_lookback_steps = 1
Expand Down
2 changes: 2 additions & 0 deletions swirl_dynamics/projects/ergodic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,8 @@ def main(argv):
measure_dist_k_lambda=config.measure_dist_k_lambda,
num_lookback_steps=config.num_lookback_steps,
normalize_stats=normalize_stats,
use_sobolev_norm=config.use_sobolev_norm,
order_sobolev_norm=config.order_sobolev_norm,
)
model = stable_ar.StableARModel(conf=model_config)

Expand Down
4 changes: 2 additions & 2 deletions swirl_dynamics/projects/ergodic/stable_ar.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,10 +150,10 @@ def loss_fn(
if self.conf.use_sobolev_norm:
dim = len(pred.shape) - 3
l2 = ergodic_utils.sobolev_norm(
pred - true[:, -1, ...], s=self.conf.order_sobolev_norm, dim=dim
pred - true[:, 1:, ...], s=self.conf.order_sobolev_norm, dim=dim
)
else:
l2 = jnp.mean(jnp.square(pred - true[:, -1, ...]))
l2 = jnp.mean(jnp.square(pred - true[:, 1:, ...]))

# Gathering the metrics together.
loss = l2
Expand Down

0 comments on commit 2b5cd84

Please sign in to comment.