Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 564562741
  • Loading branch information
The swirl_dynamics Authors committed Sep 12, 2023
1 parent a5786b2 commit 3c37b64
Show file tree
Hide file tree
Showing 5 changed files with 277 additions and 100 deletions.
55 changes: 5 additions & 50 deletions swirl_dynamics/projects/ergodic/configs/ks_1d_dist.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,6 @@

import ml_collections

# pylint: disable=line-too-long
DATA_PATH = '/datasets/gcs_staging/hdf5/pde/1d/ks_trajectories.hdf5'
# pylint: enable=line-too-long


def get_config():
Expand All @@ -35,15 +32,20 @@ def get_config():
config.save_interval_steps = 50_000
config.max_checkpoints_to_keep = 10
# Data params
config.use_tfds = True
# config.batch_size = 4096
config.batch_size = 1024
config.num_time_steps = 11
config.time_stride = 1
config.dataset_path = DATA_PATH
config.dataset_name = DATA_NAME
config.spatial_downsample_factor = 1
config.normalize = False
config.add_noise = False
config.sobolev_norm = False
config.noise_level = 0.0
config.num_time_steps_eval = 600
config.batch_size_eval = 512

# Model params
# ######## Dilated Convolutions ########
Expand Down Expand Up @@ -149,50 +151,3 @@ def sweep(add):
measure_dist_k_lambda=measure_dist_k_lambda,
)


# def sweep(add):
# """Define param sweep."""
# for seed in [21, 42, 84]:
# for measure_dist_type in ['MMD', 'SD']:
# for batch_size in [32, 64, 128, 256]:
# for lr in [1e-3, 1e-4, 1e-5]:
# # Skipping 1-step objective
# for use_curriculum in [True]: # [False, True]:
# # Running grid search on just Pfwd
# for use_pushfwd in [True]: # [False, True]:
# # Skipping all x0 regs
# for regularize_measure_dist in [False]: # [False, True]:
# for regularize_measure_dist_k in [False, True]:
# for measure_dist_lambda in [0.0, 1.0, 100.0]:
# if use_curriculum:
# train_steps_per_cycle = 50_000
# time_steps_increase_per_cycle = 1
# else:
# train_steps_per_cycle = 0
# time_steps_increase_per_cycle = 0
# if skip(
# use_curriculum,
# use_pushfwd,
# regularize_measure_dist,
# regularize_measure_dist_k,
# measure_dist_lambda,
# measure_dist_type
# ):
# continue
# add(
# num_time_steps=61,
# train_steps=3_000_000,
# save_interval_steps=250_000,
# max_checkpoints_to_keep=3_000_000//250_000,
# seed=seed,
# batch_size=batch_size,
# lr=lr,
# measure_dist_type=measure_dist_type,
# train_steps_per_cycle=train_steps_per_cycle,
# time_steps_increase_per_cycle=time_steps_increase_per_cycle,
# use_curriculum=use_curriculum,
# use_pushfwd=use_pushfwd,
# regularize_measure_dist=regularize_measure_dist,
# regularize_measure_dist_k=regularize_measure_dist_k,
# measure_dist_lambda=measure_dist_lambda,
# )
56 changes: 28 additions & 28 deletions swirl_dynamics/projects/ergodic/configs/ns_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,45 +31,45 @@ def get_config():
# Train params
config.train_steps = 360_000
config.seed = 42
config.lr = 5e-4
config.lr = 5e-5
config.metric_aggregation_steps = 50
config.save_interval_steps = 50_000
config.max_checkpoints_to_keep = 10
# Data params
config.batch_size = 50
config.batch_size = 32
config.num_time_steps = 2
config.time_stride = 1
config.dataset_path = DATA_PATH
config.spatial_downsample_factor = 1
config.normalize = False
config.normalize = True
config.add_noise = False
config.noise_level = 0.0
# Model params
# config.num_lookback_steps = 1
# config.integrator = 'OneStepDirect'
# config.model = 'PeriodicConvNetModel'
# config.latent_dim = 48
# config.num_levels = 4
# config.num_processors = 4
# config.encoder_kernel_size = (3, 3)
# config.decoder_kernel_size = (3, 3)
# config.processor_kernel_size = (3, 3)
# config.padding = 'CIRCULAR'
# config.is_input_residual = True
config.num_lookback_steps = 1
config.integrator = 'OneStepDirect'
config.model = 'PeriodicConvNetModel'
config.latent_dim = 128
config.num_levels = 2
config.num_processors = 4
config.encoder_kernel_size = (5, 5)
config.decoder_kernel_size = (5, 5)
config.processor_kernel_size = (5, 5)
config.padding = 'CIRCULAR'
config.is_input_residual = True
########### FNO ################
config.num_lookback_steps = 2
config.integrator = 'MultiStepDirect'
config.model = 'FNO'
config.out_channels = 1
config.hidden_channels = 64
config.num_modes = (20, 20)
config.lifting_channels = 256
config.projection_channels = 256
config.num_blocks = 4
config.layers_per_block = 2
config.block_skip_type = 'identity'
config.fft_norm = 'forward'
config.separable = False
# config.num_lookback_steps = 2
# config.integrator = 'MultiStepDirect'
# config.model = 'FNO'
# config.out_channels = 1
# config.hidden_channels = 64
# config.num_modes = (20, 20)
# config.lifting_channels = 256
# config.projection_channels = 256
# config.num_blocks = 4
# config.layers_per_block = 2
# config.block_skip_type = 'identity'
# config.fft_norm = 'forward'
# config.separable = False
# Update num_time_steps based on num_lookback_steps setting
config.num_time_steps += config.num_lookback_steps - 1
# Trainer params
Expand All @@ -81,7 +81,7 @@ def get_config():
config.measure_dist_type = 'MMD' # Sweepable
config.measure_dist_downsample = 1
config.measure_dist_lambda = 0.0 # Sweepable
config.measure_dist_k_lambda = 0.0 # Sweepable
config.measure_dist_k_lambda = 1.0 # Sweepable
return config


Expand Down
108 changes: 108 additions & 0 deletions swirl_dynamics/projects/ergodic/configs/ns_2d_dist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
# Copyright 2023 The swirl_dynamics Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Default Hyperparameter configuration for Navier Stokes 2D.
"""

import ml_collections

# pylint: disable=line-too-long
DATA_PATH = '/datasets/hdf5/pde/2d/ns/ns_trajectories_from_caltech.hdf5'
# DATA_PATH = '/datasets/hdf5/pde/2d/ns/attractor_spectral_grid_256_spatial_downsample_4_dt_0.001_v0_3_warmup_40.0_t_final_200.0_nu_0.001_n_samples_2000_ntraj_train_128_ntraj_eval_32_ntraj_test_32_drag_0.1_wave_number_4_random_seeds_combined_4.hdf5'
# pylint: enable=line-too-long


def get_config():
"""Get the default hyperparameter configuration."""
config = ml_collections.ConfigDict()
config.experiment = 'ns_2d'
# Train params
config.train_steps = 360_000
config.seed = 42
config.lr = 5e-5
config.metric_aggregation_steps = 50
config.save_interval_steps = 50_000
config.max_checkpoints_to_keep = 10
# Data params
config.batch_size = 64
config.num_time_steps = 10
config.time_stride = 1
config.dataset_path = DATA_PATH
config.spatial_downsample_factor = 1
config.normalize = False
config.add_noise = False
config.noise_level = 0.0
config.sobolev_norm = False

# Model params
config.num_lookback_steps = 1
config.integrator = 'OneStepDirect'
config.model = 'PeriodicConvNetModel'
config.latent_dim = 128
config.num_levels = 2
config.num_processors = 4
config.encoder_kernel_size = (3, 3)
config.decoder_kernel_size = (3, 3)
config.processor_kernel_size = (3, 3)
config.padding = 'CIRCULAR'
config.is_input_residual = True
########### FNO ################
# config.num_lookback_steps = 2
# config.integrator = 'MultiStepDirect'
# config.model = 'FNO'
# config.out_channels = 1
# config.hidden_channels = 64
# config.num_modes = (20, 20)
# config.lifting_channels = 256
# config.projection_channels = 256
# config.num_blocks = 4
# config.layers_per_block = 2
# config.block_skip_type = 'identity'
# config.fft_norm = 'forward'
# config.separable = False
# Update num_time_steps based on num_lookback_steps setting
config.num_time_steps += config.num_lookback_steps - 1
# Trainer params
config.num_rollout_steps = 1
config.train_steps_per_cycle = 0
config.time_steps_increase_per_cycle = 1
config.use_curriculum = False # Sweepable
config.use_pushfwd = False # Sweepable
config.measure_dist_downsample = 1
config.measure_dist_lambda = 0.0 # Sweepable
config.measure_dist_k_lambda = 10.0 # Sweepable
config.measure_dist_type = 'MMD_DIST' # Sweepable
config.use_distributed = True
return config


# TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
def sweep(add):
"""Define param sweep."""
for seed in [42]:
for measure_dist_type in ['MMD', 'SD']:
for measure_dist_k_lambda in [100.0, 1000.0]:
for measure_dist_lambda in [0.0]:
for measure_dist_downsample in [1, 2]:
if measure_dist_k_lambda == measure_dist_lambda == 0.0:
if measure_dist_type == 'SD' or measure_dist_downsample > 1:
continue # Avoid re-running baseline exp multiple times
add(
seed=seed,
measure_dist_type=measure_dist_type,
measure_dist_lambda=measure_dist_lambda,
measure_dist_k_lambda=measure_dist_k_lambda,
measure_dist_downsample=measure_dist_downsample,
)
66 changes: 44 additions & 22 deletions swirl_dynamics/projects/ergodic/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,28 +101,50 @@ def main(argv):
raise NotImplementedError(f"Unknown experiment: {config.experiment}")

# Dataloaders
train_loader, normalize_stats = utils.create_loader_from_hdf5(
num_time_steps=config.num_time_steps,
time_stride=config.time_stride,
batch_size=config.batch_size,
seed=config.seed,
dataset_path=config.dataset_path,
split="train",
normalize=config.normalize,
normalize_stats=None,
spatial_downsample_factor=config.spatial_downsample_factor,
)
eval_loader, _ = utils.create_loader_from_hdf5(
num_time_steps=-1,
time_stride=config.time_stride,
batch_size=-1,
seed=config.seed,
dataset_path=config.dataset_path,
split="eval",
normalize=config.normalize,
normalize_stats=normalize_stats,
spatial_downsample_factor=config.spatial_downsample_factor,
)
if "use_tfds" in config and config.use_tfds:
train_loader, normalize_stats = utils.create_loader_from_tfds(
num_time_steps=config.num_time_steps,
time_stride=config.time_stride,
batch_size=config.batch_size,
dataset_path=config.dataset_path,
dataset_name=config.dataset_name,
seed=config.seed,
normalize=config.normalize,
split="train",
)
eval_loader, _ = utils.create_loader_from_tfds(
num_time_steps=config.num_time_steps_eval,
time_stride=config.time_stride,
batch_size=config.batch_size_eval,
seed=config.seed,
dataset_path=config.dataset_path,
dataset_name=config.dataset_name,
normalize=config.normalize,
split="eval",
)
else:
train_loader, normalize_stats = utils.create_loader_from_hdf5(
num_time_steps=config.num_time_steps,
time_stride=config.time_stride,
batch_size=config.batch_size,
seed=config.seed,
dataset_path=config.dataset_path,
split="train",
normalize=config.normalize,
normalize_stats=None,
spatial_downsample_factor=config.spatial_downsample_factor,
)
eval_loader, _ = utils.create_loader_from_hdf5(
num_time_steps=-1,
time_stride=config.time_stride,
batch_size=-1,
seed=config.seed,
dataset_path=config.dataset_path,
split="eval",
normalize=config.normalize,
normalize_stats=normalize_stats,
spatial_downsample_factor=config.spatial_downsample_factor,
)

# Model
measure_dist_fn = choices.MeasureDistance(config.measure_dist_type).dispatch()
Expand Down
Loading

0 comments on commit 3c37b64

Please sign in to comment.