From e3040a665758348247d74c9c404ad0a161c9dc72 Mon Sep 17 00:00:00 2001
From: The swirl_dynamics Authors
Date: Wed, 13 Sep 2023 19:11:49 -0700
Subject: [PATCH] Code update

PiperOrigin-RevId: 565226900
---
 swirl_dynamics/lib/solvers/ode.py            |  66 ++++----
 swirl_dynamics/lib/solvers/ode_test.py       | 146 ++++++++++++++++--
 swirl_dynamics/projects/ergodic/choices.py   |  22 ++-
 .../projects/ergodic/colabs/demo.ipynb       |  18 ++-
 .../projects/ergodic/configs/ks_1d.py        | 127 ++++++++-------
 .../projects/ergodic/configs/lorenz63.py     |  83 +++++-----
 .../projects/ergodic/configs/ns_2d.py        | 120 ++++++++++----
 swirl_dynamics/projects/ergodic/ks_1d.py     |   8 +
 swirl_dynamics/projects/ergodic/lorenz63.py  |   8 +-
 swirl_dynamics/projects/ergodic/ns_2d.py     |  11 +-
 swirl_dynamics/projects/ergodic/stable_ar.py |  65 +++++---
 11 files changed, 465 insertions(+), 209 deletions(-)

diff --git a/swirl_dynamics/lib/solvers/ode.py b/swirl_dynamics/lib/solvers/ode.py
index 1ba17cc..62668f1 100644
--- a/swirl_dynamics/lib/solvers/ode.py
+++ b/swirl_dynamics/lib/solvers/ode.py
@@ -13,7 +13,6 @@
 # limitations under the License.

 """Solvers for ordinary differential equations (ODEs)."""
-
 from typing import Any, Protocol

 import flax
@@ -60,7 +59,7 @@ def step(
       self, func: OdeDynamics, x0: Array, t0: Array, dt: Array, params: PyTree
   ) -> Array:
     """Advances the current state one step forward in time."""
-    raise NotImplementedError
+    raise NotImplementedError("Scan solver must implement `step` method.")

   def __call__(
       self, func: OdeDynamics, x0: Array, tspan: Array, params: PyTree
@@ -150,45 +149,43 @@ def __call__(
     )


+@flax.struct.dataclass
 class MultiStepScanOdeSolver:
   """ODE solver based on `jax.lax.scan` that uses more than one time step.

   Rather than x_{n+1} = f(x_n, t_n), we have
   x_{n+1} = f(x_{n-k}, x_{n-k+1}, ..., x_{n-1}, x_n, t_{n-k}, ..., t_{n-1}, t_n)
   for some 'num_lookback_steps' window k.
+
+  Attributes:
+    time_axis_pos: move the time axis to the specified position in the output
+      tensor (by default it is at the 0th position). This attribute is used
+      both to indicate the temporal axis position of the input and to
+      determine where to place the temporal axis of the output.
   """

-  @staticmethod
-  def stack_timesteps_along_channel_dim(x: Array) -> Array:
+  time_axis_pos: int = 0
+
+  def stack_timesteps_along_channel_dim(self, x: Array) -> Array:
     """Helper method to package batches for multi-step solvers.

     Args:
-      x: Array of shape: (num_lookback_steps, state_dims, channels), where
-        state_dims can have ndim >= 1
+      x: Array containing axes for batch_size (potentially), lookback_steps
+        (i.e. the temporal axis), spatial_dims, and channels, where
+        spatial_dims can have ndim >= 1.

     Returns:
-      Array of shape (state_dims, num_lookback_steps*channels)
+      Array where each time step in the temporal dim is concatenated along
+      the channel axis.
     """
-    orig_shape = x.shape
-    num_lookback_steps = orig_shape[0]
-    stacked_shape = list(orig_shape)
-    # All the previous timesteps are collapsed into one
-    stacked_shape[0] = 1
-    # Concatenate steps along channel dim
-    stacked_shape[-1] *= num_lookback_steps
-    # For state_dims with ndim > 1, e.g. 
2D grids, flatten state dims: - flattened_state_size = np.prod(list(x.shape[1:-1])) - x_flattened_state_shape = ( - (x.shape[0],) + (flattened_state_size,) + (x.shape[-1],) - ) - x = x.reshape(x_flattened_state_shape) - return x.swapaxes(0, 1).reshape(stacked_shape).squeeze(axis=0) + x = jnp.moveaxis(x, self.time_axis_pos, -2) + return jnp.reshape(x, x.shape[:-2] + (-1,)) def step( self, func: OdeDynamics, x0: Array, t0: Array, dt: Array, params: PyTree ) -> Array: """Advances the current state one step forward in time.""" - raise NotImplementedError + raise NotImplementedError("Scan solver must implement `step` method.") def __call__( self, @@ -202,21 +199,26 @@ def __call__( def scan_fun( state: tuple[Array, Array], t_next: Array ) -> tuple[tuple[Array, Array], Array]: - # x0 assumed to have shape: (lookback, state_dims, channels) + # Expected dimension for x0 is either: + # - (t, ...), if time_axis_pos == 0 + # - (batch_size, ..., t, ...), if time_axis_pos > 0 x0, t0 = state - x0_stack = self.stack_timesteps_along_channel_dim(x0)[None, ...] - # input to func has shape: (state_dims, channels*lookback) + x0_stack = self.stack_timesteps_along_channel_dim(x0) dt = t_next - t0 - # return item (x_next) has shape: state_dims x channels x_next = self.step(func, x0_stack, t0, dt, params) - # carry item has same shape as x0, where we first shift over the original - # input and append the new predicted state along the time dimension - x_carry = jnp.concatenate([x0[1:, ...], x_next], axis=0) - return (x_carry, t_next), x_next.squeeze(axis=0) + x_next = jnp.expand_dims(x_next, axis=self.time_axis_pos) + x_prev = jnp.take( + x0, np.arange(1, x0.shape[self.time_axis_pos]), + axis=self.time_axis_pos + ) + x_carry = jnp.concatenate([x_prev, x_next], axis=self.time_axis_pos) + return (x_carry, t_next), x_next.squeeze(axis=self.time_axis_pos) _, out = jax.lax.scan(scan_fun, (x0, tspan[0]), tspan[1:]) - # output of scan has shape: len(tspan) - 1 x state_dims x channels - return jnp.concatenate([x0, out], axis=0) + return jnp.concatenate( + [x0, jnp.moveaxis(out, 0, self.time_axis_pos)], + axis=self.time_axis_pos + ) class MultiStepDirect(MultiStepScanOdeSolver): diff --git a/swirl_dynamics/lib/solvers/ode_test.py b/swirl_dynamics/lib/solvers/ode_test.py index 5701985..7613f8b 100644 --- a/swirl_dynamics/lib/solvers/ode_test.py +++ b/swirl_dynamics/lib/solvers/ode_test.py @@ -12,6 +12,8 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+import functools
+
 from absl.testing import absltest
 from absl.testing import parameterized
 import jax
@@ -25,6 +27,12 @@ def dummy_ode_dynamics(x, t, params):
   return jnp.ones_like(x)


+def dummy_ode_dynamics_return_original_num_channel(x, t, params, num_channels):
+  """Dummy dynamics returning the original (pre-stacking) number of channels."""
+  del t, params
+  return jnp.ones(x.shape[:-1] + (num_channels,))
+
+
 class OdeSolversTest(parameterized.TestCase):

   @parameterized.parameters(
@@ -46,18 +54,55 @@ def test_output_shape_and_value(self, solver, backward):
     self.assertEqual(out.shape, (num_steps, x_dim))
     np.testing.assert_allclose(out[-1], np.ones((x_dim,)) * tspan[-1])

-  def test_move_time_axis_pos(self):
+  def test_output_shape_and_value_one_step_direct(self):
     dt = 0.1
     num_steps = 10
     x_dim = 5
+    tspan = jnp.arange(num_steps) * dt
+    solver = ode.OneStepDirect()
+    out = solver(dummy_ode_dynamics, jnp.zeros((x_dim,)), tspan, {})
+    self.assertEqual(out.shape, (num_steps, x_dim))
+    np.testing.assert_array_equal(out[-1], np.ones((x_dim,)))
+
+  @parameterized.parameters(
+      {"time_axis_pos": 1, "x_dim": (5,)},
+      {"time_axis_pos": 2, "x_dim": (5, 5)},
+  )
+  def test_move_time_axis_pos(self, time_axis_pos, x_dim):
+    dt = 0.1
+    num_steps = 10
     batch_sz = 6
+    input_shape = (batch_sz,) + x_dim
     tspan = jnp.arange(num_steps) * dt
-    out = ode.ExplicitEuler(time_axis_pos=1)(
-        dummy_ode_dynamics, jnp.zeros((batch_sz, x_dim)), tspan, {}
+    solver = ode.ExplicitEuler(time_axis_pos=time_axis_pos)
+    expected_output_shape = (
+        input_shape[:time_axis_pos] + (num_steps,) + input_shape[time_axis_pos:]
     )
-    self.assertEqual(out.shape, (batch_sz, num_steps, x_dim))
+    out = solver(dummy_ode_dynamics, jnp.zeros(input_shape), tspan, {})
+    self.assertEqual(out.shape, expected_output_shape)
     np.testing.assert_allclose(
-        out[:, -1], np.ones((batch_sz, x_dim)) * tspan[-1]
+        jnp.moveaxis(out, time_axis_pos, 1)[:, -1],
+        np.ones(input_shape) * tspan[-1],
+    )
+
+  @parameterized.parameters(
+      {"time_axis_pos": 1, "x_dim": (5,)},
+      {"time_axis_pos": 2, "x_dim": (5, 5)},
+  )
+  def test_move_time_axis_pos_one_step_direct(self, time_axis_pos, x_dim):
+    dt = 0.1
+    num_steps = 10
+    batch_sz = 6
+    input_shape = (batch_sz,) + x_dim
+    tspan = jnp.arange(num_steps) * dt
+    solver = ode.OneStepDirect(time_axis_pos=time_axis_pos)
+    out = solver(dummy_ode_dynamics, jnp.zeros(input_shape), tspan, {})
+    expected_output_shape = (
+        input_shape[:time_axis_pos] + (num_steps,) + input_shape[time_axis_pos:]
+    )
+    self.assertEqual(out.shape, expected_output_shape)
+    np.testing.assert_array_equal(
+        jnp.moveaxis(out, time_axis_pos, 1)[:, -1], np.ones(input_shape)
     )

   @parameterized.parameters((np.arange(10) * -1,), (np.zeros(10),))
@@ -70,37 +115,108 @@ def test_dopri45_backward_error(self, tspan):
 class MultiStepOdeSolversTest(parameterized.TestCase):

   @parameterized.product(
+      time_axis_pos=(0, 1, 2),
+      batch_size=(1, 8),
       state_dim=((512,), (64, 64), (32, 32, 32)),
       channels=(1, 2, 3),
       num_lookback_steps=(2, 4, 8),
   )
   def test_stacked_output_shape_and_value(
       self,
+      time_axis_pos,
+      batch_size,
       state_dim,
       channels,
       num_lookback_steps,
   ):
+    # Set up initial shape with time in axis 0
+    input_shape = [num_lookback_steps, batch_size]
+    input_shape.extend(state_dim)
+    input_shape.append(channels)
+    # Re-order shape to match time_axis_pos (for time_axis_pos=0, a no-op)
+    input_shape[0] = input_shape[time_axis_pos]
+    input_shape[time_axis_pos] = num_lookback_steps
+
     rng = jax.random.PRNGKey(0)
-    input_shape = (num_lookback_steps,) + state_dim + (channels,)
     input_state = 
jax.random.normal(rng, input_shape)
-    input_state_stacked = (
-        ode.MultiStepScanOdeSolver.stack_timesteps_along_channel_dim(
-            input_state
-        )
+    input_state_stacked = ode.MultiStepScanOdeSolver(
+        time_axis_pos=time_axis_pos
+    ).stack_timesteps_along_channel_dim(input_state)
+    expected_output_shape = (
+        input_shape[:time_axis_pos]
+        + input_shape[time_axis_pos + 1 : -1]
+        + [channels * num_lookback_steps]
     )
     # Check that expected shapes match
-    self.assertEqual(
-        input_state_stacked.shape, state_dim + (channels * num_lookback_steps,)
-    )
-    # Check that timesteps were correctly concatenated along channel dim
+    self.assertEqual(input_state_stacked.shape, tuple(expected_output_shape))
+    # Check that timesteps were correctly concatenated along channel dim
     for w in range(num_lookback_steps):
       c_start = channels * w
       c_end = channels * (w + 1)
       np.testing.assert_array_equal(
-          input_state[w, ...],
+          jnp.moveaxis(input_state, time_axis_pos, 0)[w, ...],
          input_state_stacked[..., c_start:c_end],
       )

+  @parameterized.product(
+      time_axis_pos=(0, 1, 2),
+      batch_size=(1, 8),
+      state_dim=((512,), (64, 64), (32, 32, 32)),
+      channels=(1, 2, 3),
+      num_lookback_steps=(2, 4, 8),
+  )
+  def test_output_shape_and_value_multi_step_direct(
+      self,
+      time_axis_pos,
+      batch_size,
+      state_dim,
+      channels,
+      num_lookback_steps,
+  ):
+    # Set up initial shape with time in axis 0
+    input_shape = [num_lookback_steps, batch_size]
+    input_shape.extend(state_dim)
+    input_shape.append(channels)
+    # Re-order shape to match time_axis_pos (for time_axis_pos=0, a no-op)
+    input_shape[0] = input_shape[time_axis_pos]
+    input_shape[time_axis_pos] = num_lookback_steps
+    dt = 0.1
+    num_steps = 10
+    tspan = jnp.arange(num_steps) * dt
+    solver = ode.MultiStepDirect(time_axis_pos=time_axis_pos)
+    out = solver(
+        functools.partial(
+            dummy_ode_dynamics_return_original_num_channel,
+            num_channels=channels,
+        ),
+        jnp.zeros(input_shape),
+        tspan,
+        {},
+    )
+    expected_output_shape = (
+        input_shape[:time_axis_pos]
+        + [num_steps]
+        + input_shape[time_axis_pos + 1 :]
+    )
+    out = jnp.take(
+        out,
+        np.arange(num_lookback_steps - 1, out.shape[time_axis_pos]),
+        axis=time_axis_pos,
+    )
+    self.assertEqual(out.shape, tuple(expected_output_shape))
+    np.testing.assert_array_equal(
+        jnp.take(
+            out,
+            np.arange(1, out.shape[time_axis_pos]),
+            axis=time_axis_pos,
+        ),
+        jnp.take(
+            np.ones(expected_output_shape),
+            np.arange(1, out.shape[time_axis_pos]),
+            axis=time_axis_pos,
+        ),
+    )
+

 if __name__ == "__main__":
   absltest.main()
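Note: the reworked `stack_timesteps_along_channel_dim` above is now just a moveaxis followed by a reshape. A minimal sketch of the stacking semantics, assuming only `jax.numpy`; the shapes and variable names are illustrative, not from the library:

    import jax.numpy as jnp

    # Batch of 8, 2 lookback steps at time_axis_pos=1 (the layout dispatched
    # by choices.Integrator below), a 4-point grid, and 3 channels.
    x = jnp.zeros((8, 2, 4, 3))         # (batch, time, space, channels)
    y = jnp.moveaxis(x, 1, -2)          # (batch, space, time, channels)
    stacked = jnp.reshape(y, y.shape[:-2] + (-1,))
    assert stacked.shape == (8, 4, 6)   # time folded into channels: 2 * 3

Each lookback step occupies a contiguous slice of the channel axis, which is exactly what the `c_start:c_end` comparisons in the test above verify.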
diff --git a/swirl_dynamics/projects/ergodic/choices.py b/swirl_dynamics/projects/ergodic/choices.py
index c94b502..a7fd049 100644
--- a/swirl_dynamics/projects/ergodic/choices.py
+++ b/swirl_dynamics/projects/ergodic/choices.py
@@ -59,17 +59,19 @@ class Integrator(enum.Enum):

   def dispatch(
       self,
-  ) -> type[ode.ScanOdeSolver] | type[ode.MultiStepScanOdeSolver]:
+  ) -> ode.ScanOdeSolver | ode.MultiStepScanOdeSolver:
     """Dispatch integrator.

     Returns:
       ScanOdeSolver | MultiStepScanOdeSolver
     """
+    # TODO(yairschiff): Profile if the moveaxis call required here introduces a
+    # bottleneck
     return {
-        "ExplicitEuler": ode.ExplicitEuler,
-        "RungeKutta4": ode.RungeKutta4,
-        "OneStepDirect": ode.OneStepDirect,
-        "MultiStepDirect": ode.MultiStepDirect,
+        "ExplicitEuler": ode.ExplicitEuler(time_axis_pos=1),
+        "RungeKutta4": ode.RungeKutta4(time_axis_pos=1),
+        "OneStepDirect": ode.OneStepDirect(time_axis_pos=1),
+        "MultiStepDirect": ode.MultiStepDirect(time_axis_pos=1),
     }[self.value]


@@ -109,7 +111,8 @@ def dispatch(
 class Model(enum.Enum):
   """Model choices."""

-  FNO = "FNO"
+  FNO = "Fno"
+  FNO_2D = "Fno2d"
   MLP = "MLP"
   PERIODIC_CONV_NET_MODEL = "PeriodicConvNetModel"

@@ -148,4 +151,11 @@ def dispatch(self, conf: ml_collections.ConfigDict) -> nn.Module:
         fft_norm=conf.fft_norm,
         separable=conf.separable,
     )
+    if self.value == Model.FNO_2D.value:
+      return fno.Fno2d(
+          out_channels=conf.out_channels,
+          num_modes=conf.num_modes,
+          width=conf.width,
+          fft_norm=conf.fft_norm,
+      )
   raise ValueError()
diff --git a/swirl_dynamics/projects/ergodic/colabs/demo.ipynb b/swirl_dynamics/projects/ergodic/colabs/demo.ipynb
index bf9717d..02ab8c4 100644
--- a/swirl_dynamics/projects/ergodic/colabs/demo.ipynb
+++ b/swirl_dynamics/projects/ergodic/colabs/demo.ipynb
@@ -129,11 +129,11 @@
    "outputs": [],
    "source": [
     "experiment = \"ks_1d\" #@param choices=['lorenz63', 'ks_1d', 'ns_2d']\n",
-    "batch_size = 50 #@param {type:\"integer\"}\n",
+    "batch_size = 128 #@param {type:\"integer\"}\n",
     "measure_dist_type = \"MMD\" #@param choices=['MMD', 'SD']\n",
     "normalize = False #@param {type:\"boolean\"}\n",
     "add_noise = False #@param {type:\"boolean\"}\n",
-    "use_curriculum = False #@param {type:\"boolean\"}\n",
+    "use_curriculum = True #@param {type:\"boolean\"}\n",
     "use_pushfwd = False #@param {type:\"boolean\"}\n",
     "measure_dist_lambda = 0.0 #@param {type:\"number\"}\n",
     "measure_dist_k_lambda = 0.0 #@param {type:\"number\"}\n",
@@ -311,7 +311,7 @@
     " ),\n",
     " callbacks.TqdmProgressBar(\n",
     " total_train_steps=config.train_steps,\n",
-    " train_monitors=[\"loss\", \"measure_dist\", \"measure_dist_k\"],\n",
+    " train_monitors=[\"rollout\", \"loss\", \"measure_dist\", \"measure_dist_k\"],\n",
     " eval_monitors=[\"sd\"],\n",
     " ),\n",
     " fig_callback_cls()\n",
@@ -323,7 +323,17 @@
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
-    "id": "CCBlVwYaKrhy"
+    "id": "cxxB4OZBuAA2"
+  },
+  "outputs": [],
+  "source": [
+  ]
+ },
+ {
+  "cell_type": "code",
+  "execution_count": null,
+  "metadata": {
+    "id": "2AT2YThCyNPt"
   },
   "outputs": [],
   "source": []
diff --git a/swirl_dynamics/projects/ergodic/configs/ks_1d.py b/swirl_dynamics/projects/ergodic/configs/ks_1d.py
index 56fb6fb..6554cc7 100644
--- a/swirl_dynamics/projects/ergodic/configs/ks_1d.py
+++ b/swirl_dynamics/projects/ergodic/configs/ks_1d.py
@@ -35,9 +35,13 @@ def get_config():
   config.save_interval_steps = 5_000
   config.max_checkpoints_to_keep = 10
   # Data params
-  config.batch_size = 32
-  config.num_time_steps = 2
-  config.time_stride = 1
+  config.batch_size = 128
+  # num_time_steps is the length of the ground truth trajectories in each
+  # batch from the dataloader (this differs from num_time_steps in the
+  # Trainer.preprocess_batch functions, which corresponds to the length of the
+  # ground truth trajectories actually passed to the model and can vary during
+  # training).
+ config.num_time_steps = 11 + config.time_stride = 1 # factor for downsampling time dim of ground truth config.dataset_path = DATA_PATH config.spatial_downsample_factor = 1 config.normalize = False @@ -46,37 +50,38 @@ def get_config(): config.order_sobolev_norm = 1 config.noise_level = 0.0 + # TODO(yairschiff): Split different models into separate configs # Model params + ########### PeriodicConvNetModel ################ + # config.model = 'PeriodicConvNetModel' + # config.latent_dim = 48 + # config.num_levels = 4 + # config.num_processors = 4 + # config.encoder_kernel_size = (5,) + # config.decoder_kernel_size = (5,) + # config.processor_kernel_size = (5,) + # config.padding = 'CIRCULAR' + # config.is_input_residual = True + ########### FNO ################ + config.model = 'Fno' + config.out_channels = 1 + config.hidden_channels = 64 + config.num_modes = (24,) + config.lifting_channels = 256 + config.projection_channels = 256 + config.num_blocks = 4 + config.layers_per_block = 2 + config.block_skip_type = 'identity' + config.fft_norm = 'forward' + config.separable = False - ####### Dilated convolutions ########## config.num_lookback_steps = 1 - config.integrator = 'OneStepDirect' - config.model = 'PeriodicConvNetModel' - config.latent_dim = 48 - config.num_levels = 4 - config.num_processors = 4 - config.encoder_kernel_size = (5,) - config.decoder_kernel_size = (5,) - config.processor_kernel_size = (5,) - config.padding = 'CIRCULAR' - config.is_input_residual = True - - # ########### FNO ################ - # config.num_lookback_steps = 2 - # config.integrator = 'MultiStepDirect' - # config.model = 'FNO' - # config.out_channels = 1 - # config.hidden_channels = 64 - # config.num_modes = (24,) - # config.lifting_channels = 256 - # config.projection_channels = 256 - # config.num_blocks = 4 - # config.layers_per_block = 2 - # config.block_skip_type = 'identity' - # config.fft_norm = 'forward' - # config.separable = False - # Update num_time_steps based on num_lookback_steps setting + # Update num_time_steps and integrator based on num_lookback_steps setting config.num_time_steps += config.num_lookback_steps - 1 + if config.num_lookback_steps > 1: + config.integrator = 'MultiStepDirect' + else: + config.integrator = 'OneStepDirect' # Trainer params config.num_rollout_steps = 1 config.train_steps_per_cycle = 5_000 @@ -115,39 +120,41 @@ def skip( def sweep(add): """Define param sweep.""" for seed in [42]: - for measure_dist_type in ['MMD', 'SD']: + for normalize in [False, True]: for batch_size in [128]: for lr in [1e-4]: - for use_curriculum in [True]: + for use_curriculum in [False, True]: for use_pushfwd in [False, True]: - for measure_dist_lambda in [0.0, 1.0, 100.0]: - for measure_dist_k_lambda in [0.0, 1.0, 100.0]: - if use_curriculum: - train_steps_per_cycle = 50_000 - time_steps_increase_per_cycle = 1 - else: - train_steps_per_cycle = 0 - time_steps_increase_per_cycle = 0 - if skip( - use_curriculum, - use_pushfwd, - measure_dist_lambda, - measure_dist_k_lambda, - measure_dist_type, - ): - continue - add( - seed=seed, - batch_size=batch_size, - lr=lr, - measure_dist_type=measure_dist_type, - train_steps_per_cycle=train_steps_per_cycle, - time_steps_increase_per_cycle=time_steps_increase_per_cycle, - use_curriculum=use_curriculum, - use_pushfwd=use_pushfwd, - measure_dist_lambda=measure_dist_lambda, - measure_dist_k_lambda=measure_dist_k_lambda, - ) + for measure_dist_type in ['MMD', 'SD']: + for measure_dist_lambda in [0.0, 1.0]: + for measure_dist_k_lambda in [0.0, 1.0, 100.0]: + if 
use_curriculum:
+                      train_steps_per_cycle = 50_000
+                      time_steps_increase_per_cycle = 1
+                    else:
+                      train_steps_per_cycle = 0
+                      time_steps_increase_per_cycle = 0
+                    if skip(
+                        use_curriculum,
+                        use_pushfwd,
+                        measure_dist_lambda,
+                        measure_dist_k_lambda,
+                        measure_dist_type,
+                    ):
+                      continue
+                    add(
+                        seed=seed,
+                        batch_size=batch_size,
+                        normalize=normalize,
+                        lr=lr,
+                        measure_dist_type=measure_dist_type,
+                        train_steps_per_cycle=train_steps_per_cycle,
+                        time_steps_increase_per_cycle=time_steps_increase_per_cycle,
+                        use_curriculum=use_curriculum,
+                        use_pushfwd=use_pushfwd,
+                        measure_dist_lambda=measure_dist_lambda,
+                        measure_dist_k_lambda=measure_dist_k_lambda,
+                    )


 # def sweep(add):
diff --git a/swirl_dynamics/projects/ergodic/configs/lorenz63.py b/swirl_dynamics/projects/ergodic/configs/lorenz63.py
index 872deda..28ec202 100644
--- a/swirl_dynamics/projects/ergodic/configs/lorenz63.py
+++ b/swirl_dynamics/projects/ergodic/configs/lorenz63.py
@@ -19,7 +19,7 @@
 import ml_collections

 # pylint: disable=line-too-long
-DATA_PATH = '/datasets/gcs_staging/hdf5/ode/1d/lorenz63_trajectories.hdf5'
+DATA_PATH = '/datasets/hdf5/ode/lorenz63_seed_42_ntraj_train_5000_ntraj_eval_10000_ntraj_test_10000_nsteps_100000_nwarmup_100000_dt_0.001_dt_downsample_10.hdf5'
 # pylint: enable=line-too-long


@@ -38,9 +38,13 @@ def get_config():
   config.order_sobolev_norm = 0

   # Data params
-  config.batch_size = 4096
+  config.batch_size = 2048
+  # num_time_steps is the length of the ground truth trajectories in each
+  # batch from the dataloader (this differs from num_time_steps in the
+  # Trainer.preprocess_batch functions, which corresponds to the length of the
+  # ground truth trajectories actually passed to the model and can vary during
+  # training).
   config.num_time_steps = 11
-  config.time_stride = 10
+  config.time_stride = 1  # factor for downsampling time dim of ground truth
   config.dataset_path = DATA_PATH
   config.spatial_downsample_factor = 1
   config.normalize = False
@@ -77,6 +81,8 @@ def skip(

   if not use_curriculum and use_pushfwd:
     return True
+  if measure_dist_lambda > 0.0 and measure_dist_k_lambda == 0.0:
+    return True
   if (
       measure_dist_type == 'SD'
       and measure_dist_lambda == 0.0
@@ -89,35 +95,42 @@
     return True
   return False


 # TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
def sweep(add):
   """Define param sweep."""
-  for seed in [42]:
-    for measure_dist_type in ['MMD', 'SD']:
-      for batch_size in [2048]:
-        for use_curriculum in [False, True]:
-          for use_pushfwd in [False, True]:
-            for measure_dist_lambda in [0.0, 1.0, 10.0, 100.0, 1000.0]:
-              for measure_dist_k_lambda in [0.0, 1.0, 10.0, 100.0, 1000.0]:
-                if use_curriculum:
-                  train_steps_per_cycle = 50_000
-                  time_steps_increase_per_cycle = 1
-                else:
-                  train_steps_per_cycle = 0
-                  time_steps_increase_per_cycle = 0
-                if skip(
-                    use_curriculum,
-                    use_pushfwd,
-                    measure_dist_lambda,
-                    measure_dist_k_lambda,
-                    measure_dist_type,
-                ):
-                  continue
-                add(
-                    seed=seed,
-                    batch_size=batch_size,
-                    measure_dist_type=measure_dist_type,
-                    train_steps_per_cycle=train_steps_per_cycle,
-                    time_steps_increase_per_cycle=time_steps_increase_per_cycle,
-                    use_curriculum=use_curriculum,
-                    use_pushfwd=use_pushfwd,
-                    measure_dist_lambda=measure_dist_lambda,
-                    measure_dist_k_lambda=measure_dist_k_lambda,
-                )
+  for seed in [21, 42]:
+    for batch_size in [2048]:
+      for time_stride in [50]:
+        for normalize in [True]:
+          for measure_dist_type in ['MMD', 'SD']:
+            for use_curriculum in [True]:
+              for use_pushfwd in [False]:
+                for measure_dist_lambda in [0.0]:  #, 1.0]:
+                  for measure_dist_k_lambda in [0.0]:  #, 1.0, 1000.0]:
+                    if use_curriculum:
+                      train_steps = 2_000_000
+                      train_steps_per_cycle = 200_000
+                      time_steps_increase_per_cycle = 1
+                    else:
+                      train_steps = 500_000
+                      train_steps_per_cycle = 0
+                      time_steps_increase_per_cycle = 0
+                    if skip(
+                        use_curriculum,
+                        use_pushfwd,
+                        measure_dist_lambda,
+                        measure_dist_k_lambda,
+                        measure_dist_type,
+                    ):
+                      continue
+                    add(
+                        seed=seed,
+                        batch_size=batch_size,
+                        time_stride=time_stride,
+                        normalize=normalize,
+                        measure_dist_type=measure_dist_type,
+                        train_steps=train_steps,
+                        train_steps_per_cycle=train_steps_per_cycle,
+                        time_steps_increase_per_cycle=time_steps_increase_per_cycle,
+                        use_curriculum=use_curriculum,
+                        use_pushfwd=use_pushfwd,
+                        measure_dist_lambda=measure_dist_lambda,
+                        measure_dist_k_lambda=measure_dist_k_lambda,
+                    )
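Note: the curriculum parameters swept above are only meaningful under an assumption about the trainer: as the names suggest, the rollout length grows by `time_steps_increase_per_cycle` after every `train_steps_per_cycle` training steps. A hypothetical sketch of that schedule (not the trainer's actual code):

    def rollout_length(step, base_length=1, steps_per_cycle=200_000,
                       increase_per_cycle=1):
      """Rollout length after `step` training steps under the curriculum."""
      if steps_per_cycle == 0:
        return base_length  # curriculum disabled
      return base_length + (step // steps_per_cycle) * increase_per_cycle

With `train_steps = 2_000_000` and `train_steps_per_cycle = 200_000` as above, the run passes through ten cycles and ends training on rollouts nine steps longer than it started with.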
diff --git a/swirl_dynamics/projects/ergodic/configs/ns_2d.py b/swirl_dynamics/projects/ergodic/configs/ns_2d.py
index e29b930..06ae6dd 100644
--- a/swirl_dynamics/projects/ergodic/configs/ns_2d.py
+++ b/swirl_dynamics/projects/ergodic/configs/ns_2d.py
@@ -33,36 +33,40 @@ def get_config():
   config.seed = 42
   config.lr = 5e-5
   config.metric_aggregation_steps = 50
-  config.save_interval_steps = 50_000
+  config.save_interval_steps = 36_000
   config.max_checkpoints_to_keep = 10
   config.use_sobolev_norm = True
   config.order_sobolev_norm = 1

   # Data params
-  config.batch_size = 32
-  config.num_time_steps = 2
-  config.time_stride = 1
+  config.batch_size = 50
+  # num_time_steps is the length of the ground truth trajectories in each
+  # batch from the dataloader (this differs from num_time_steps in the
+  # Trainer.preprocess_batch functions, which corresponds to the length of the
+  # ground truth trajectories actually passed to the model and can vary during
+  # training).
+  config.num_time_steps = 11
+  config.time_stride = 1  # factor for downsampling time dim of ground truth
   config.dataset_path = DATA_PATH
   config.spatial_downsample_factor = 1
   config.normalize = True
   config.add_noise = False
   config.noise_level = 0.0
+
   # Model params
-  config.num_lookback_steps = 1
-  config.integrator = 'OneStepDirect'
+  ########### PeriodicConvNetModel ################
   config.model = 'PeriodicConvNetModel'
-  config.latent_dim = 128
-  config.num_levels = 2
+  config.latent_dim = 96
+  config.num_levels = 4
   config.num_processors = 4
-  config.encoder_kernel_size = (5, 5)
-  config.decoder_kernel_size = (5, 5)
-  config.processor_kernel_size = (5, 5)
+  config.encoder_kernel_size = (3, 3)
+  config.decoder_kernel_size = (3, 3)
+  config.processor_kernel_size = (3, 3)
   config.padding = 'CIRCULAR'
   config.is_input_residual = True
-  ########### FNO ################
-  # config.num_lookback_steps = 2
-  # config.integrator = 'MultiStepDirect'
-  # config.model = 'FNO'
+
+  # ########### FNO ################
+  # config.num_lookback_steps = 1
+  # config.model = 'Fno'
   # config.out_channels = 1
   # config.hidden_channels = 64
   # config.num_modes = (20, 20)
@@ -73,8 +77,20 @@ def get_config():
   # config.block_skip_type = 'identity'
   # config.fft_norm = 'forward'
   # config.separable = False
-  # Update num_time_steps based on num_lookback_steps setting
-  config.num_time_steps += config.num_lookback_steps - 1
+
+  ########### FNO 2D ###############
+  # config.model = 'Fno2d'
+  # config.out_channels = 1
+  # config.num_modes = (20, 20)
+  # config.width = 128
+  # config.fft_norm = 'ortho'
+
+  config.num_lookback_steps = 1
+  # Update the integrator based on the num_lookback_steps setting
+  if config.num_lookback_steps > 1:
+    config.integrator = 'MultiStepDirect'
+  else:
+    config.integrator = 'OneStepDirect'
   # Trainer params
   config.num_rollout_steps = 1
   config.train_steps_per_cycle = 0
@@ -88,21 +104,63 @@ def get_config():

   return config
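Note: `get_config` above (and its `ks_1d.py` counterpart) now derives the integrator from the lookback window instead of hard-coding it. A short sketch of how that string reaches a solver via `choices.Integrator` from this patch (illustrative; `config` is the dict returned above):

    from swirl_dynamics.projects.ergodic import choices

    # Multi-step integrators are only needed when more than one lookback
    # step feeds the model.
    integrator = choices.Integrator(config.integrator)
    solver = integrator.dispatch()  # a solver instance with time_axis_pos=1

With this change `dispatch()` returns ready-made instances (batch axis first, time at axis 1) rather than classes, which is why the `dispatch()()` double call in `stable_ar.py` below becomes a single call.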


+def skip(
+    use_curriculum: bool,
+    use_pushfwd: bool,
+    measure_dist_lambda: float,
+    measure_dist_k_lambda: float,
+    measure_dist_type: str,
+) -> bool:
+  """Helper method for avoiding unwanted runs in sweep."""
+
+  if not use_curriculum and use_pushfwd:
+    return True
+  if (
+      measure_dist_type == 'SD'
+      and measure_dist_lambda == 0.0
+      and measure_dist_k_lambda == 0.0
+  ):
+    return True
+  return False
+
+
 # TODO(yairschiff): Refactor sweeps and experiment definition to use gin.
+# Use option --sweep=False on the command line to avoid sweeping.
 def sweep(add):
   """Define param sweep."""
   for seed in [42]:
-    for measure_dist_type in ['MMD', 'SD']:
-      for measure_dist_k_lambda in [100.0, 1000.0]:
-        for measure_dist_lambda in [0.0]:
-          for measure_dist_downsample in [1, 2]:
-            if measure_dist_k_lambda == measure_dist_lambda == 0.0:
-              if measure_dist_type == 'SD' or measure_dist_downsample > 1:
-                continue  # Avoid re-running baseline exp multiple times
-            add(
-                seed=seed,
-                measure_dist_type=measure_dist_type,
-                measure_dist_lambda=measure_dist_lambda,
-                measure_dist_k_lambda=measure_dist_k_lambda,
-                measure_dist_downsample=measure_dist_downsample,
-            )
+    for normalize in [False, True]:
+      for measure_dist_type in ['MMD', 'SD']:
+        for batch_size in [50]:
+          for lr in [5e-4]:
+            for use_curriculum in [False]:
+              for use_pushfwd in [False]:
+                for measure_dist_lambda in [0.0, 1.0]:
+                  for measure_dist_k_lambda in [0.0, 1.0, 100.0]:
+                    if use_curriculum:
+                      train_steps_per_cycle = 72_000
+                      time_steps_increase_per_cycle = 1
+                    else:
+                      train_steps_per_cycle = 0
+                      time_steps_increase_per_cycle = 0
+                    if skip(
+                        use_curriculum,
+                        use_pushfwd,
+                        measure_dist_lambda,
+                        measure_dist_k_lambda,
+                        measure_dist_type,
+                    ):
+                      continue
+                    add(
+                        seed=seed,
+                        batch_size=batch_size,
+                        normalize=normalize,
+                        lr=lr,
+                        measure_dist_type=measure_dist_type,
+                        train_steps_per_cycle=train_steps_per_cycle,
+                        time_steps_increase_per_cycle=time_steps_increase_per_cycle,
+                        use_curriculum=use_curriculum,
+                        use_pushfwd=use_pushfwd,
+                        measure_dist_lambda=measure_dist_lambda,
+                        measure_dist_k_lambda=measure_dist_k_lambda,
+                    )
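Note: `skip` prunes grid points whose loss configuration duplicates an already-covered run. An illustrative count for the measure-distance part of the sweep above, assuming the same semantics as the helper (only the standard library is used):

    import itertools

    kept = [
        (mtype, lam, k_lam)
        for mtype, lam, k_lam in itertools.product(
            ['MMD', 'SD'], [0.0, 1.0], [0.0, 1.0, 100.0]
        )
        if not (mtype == 'SD' and lam == 0.0 and k_lam == 0.0)
    ]
    assert len(kept) == 11  # unregularized baseline runs only under MMD

so the (0.0, 0.0) baseline is not re-run once per measure-distance type.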
diff --git a/swirl_dynamics/projects/ergodic/ks_1d.py b/swirl_dynamics/projects/ergodic/ks_1d.py
index 989db20..e571141 100644
--- a/swirl_dynamics/projects/ergodic/ks_1d.py
+++ b/swirl_dynamics/projects/ergodic/ks_1d.py
@@ -31,6 +31,14 @@ def plot_trajectories(
     dt, x_grid, traj_length, trajs, pred_trajs, case_ids=(11, 32, 67, 89)
 ):
   """Plot sample trajectories."""
+  assert trajs.shape[0] > max(case_ids), (
+      "Ground truth trajectories do not contain enough samples"
+      f" ({trajs.shape[0]}) to select trajectory number {max(case_ids)}."
+  )
+  assert pred_trajs.shape[0] > max(case_ids), (
+      "Predicted trajectories do not contain enough samples"
+      f" ({pred_trajs.shape[0]}) to select trajectory number {max(case_ids)}."
+  )
   plot_time = jnp.arange(traj_length) * dt
   t_max = plot_time.max()
   fig = plt.figure(figsize=(6, 6 * len(case_ids)), constrained_layout=True)
diff --git a/swirl_dynamics/projects/ergodic/lorenz63.py b/swirl_dynamics/projects/ergodic/lorenz63.py
index a3d4cd8..36297d6 100644
--- a/swirl_dynamics/projects/ergodic/lorenz63.py
+++ b/swirl_dynamics/projects/ergodic/lorenz63.py
@@ -209,6 +209,11 @@ def plot_correlations(dt, traj_length, trajs, pred_trajs):
 class Lorenz63PlotFigures(stable_ar.PlotFigures):
   """Lorenz 63 plotting."""

+  def __init__(self, corr_plot_steps: int = 2000):
+    super().__init__()
+    # Correlation breaks down early, do not need all the steps
+    self.corr_plot_steps = corr_plot_steps
+
   def on_eval_batches_end(
       self, trainer: callbacks.Trainer, eval_metrics: Mapping[str, Array]
   ) -> None:
@@ -231,8 +236,7 @@ def on_eval_batches_end(
     figs.update(
         plot_correlations(
             dt=dt,
-            # Correlation breaks down early, do not need all the steps
-            traj_length=min(traj_length, 20),
+            traj_length=min(traj_length, self.corr_plot_steps),
             trajs=eval_metrics["all_trajs"]["trajs"],
             pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
         )
diff --git a/swirl_dynamics/projects/ergodic/ns_2d.py b/swirl_dynamics/projects/ergodic/ns_2d.py
index c1bbf19..a9793e5 100644
--- a/swirl_dynamics/projects/ergodic/ns_2d.py
+++ b/swirl_dynamics/projects/ergodic/ns_2d.py
@@ -29,12 +29,19 @@ def plot_trajectories(
     dt, traj_lengths, trajs, pred_trajs, case_ids=(11, 32, 67, 89)
 ):
   """Plot sample trajectories."""
+  assert trajs.shape[0] > max(case_ids), (
+      "Ground truth trajectories do not contain enough samples"
+      f" ({trajs.shape[0]}) to select trajectory number {max(case_ids)}."
+  )
+  assert pred_trajs.shape[0] > max(case_ids), (
+      "Predicted trajectories do not contain enough samples"
+      f" ({pred_trajs.shape[0]}) to select trajectory number {max(case_ids)}."
+  )
   fig = plt.figure(
       figsize=(3 * len(traj_lengths), 6 * len(case_ids)),
       constrained_layout=True,
   )
   subfigs = fig.subfigures(nrows=len(case_ids), ncols=1)
-
   for case_id, subfig in zip(case_ids, subfigs):
     ax = subfig.subplots(
         nrows=2, ncols=len(traj_lengths), sharex=True, sharey=True
@@ -69,7 +76,7 @@ def on_eval_batches_end(
     figs.update(
         plot_trajectories(
             dt=dt,
-            traj_lengths=[0, traj_length // 4, traj_length // 2, traj_length-1],
+            traj_lengths=[0, 1, 2, 5, traj_length // 2, traj_length - 1],
             trajs=eval_metrics["all_trajs"]["trajs"],
             pred_trajs=eval_metrics["all_trajs"]["pred_trajs"],
         )
diff --git a/swirl_dynamics/projects/ergodic/stable_ar.py b/swirl_dynamics/projects/ergodic/stable_ar.py
index 64e889f..f3729c7 100644
--- a/swirl_dynamics/projects/ergodic/stable_ar.py
+++ b/swirl_dynamics/projects/ergodic/stable_ar.py
@@ -57,7 +57,7 @@ class StableARModelConfig:
   num_lookback_steps: int = 1
   use_sobolev_norm: bool = False
   order_sobolev_norm: int = 1
-  normalize_stats: dict[str, Array] | None = None
+  normalize_stats: dict[str, Array | None] | None = None


 @dataclasses.dataclass(kw_only=True)
@@ -67,14 +67,11 @@ class StableARModel(models.BaseModel):

   conf: StableARModelConfig

   def __post_init__(self):
-    pred_integrator = self.conf.integrator.dispatch()()
-    pred_integrator = functools.partial(
+    pred_integrator = self.conf.integrator.dispatch()
+    self.pred_integrator = functools.partial(
         pred_integrator,
         solver_utils.nn_module_to_ode_dynamics(self.conf.dynamics_model),
     )
-    self.vmapped_pred_integrator = jax.vmap(
-        pred_integrator, in_axes=(0, None, None)
-    )
     # TODO(lzepedanunez): check if this is compatible with distributed training.
     self.vmapped_measure_dist = jax.vmap(self.conf.measure_dist, in_axes=(1, 1))

@@ -90,6 +87,10 @@ def loss_fn(
       mutables: PyTree,
   ) -> models.LossAndAux:
     """Computes training loss and metrics."""
+    # For expected shape comments below, ... corresponds to:
+    # - Lorenz: 3
+    # - KS: spatial_dim, 1
+    # - NS: spatial_dim, spatial_dim, 1
     true = batch["true"]
     x0 = batch["x0"]
     # When using data parallelism, it will add an extra dimension due to the
@@ -102,18 +103,22 @@ def loss_fn(
       x0 += noise
     if self.conf.use_pushfwd:
       # Rollout for t-1 steps with stop gradient
+      # Expected shape: (bsz, num_rollout_steps+num_lookback_steps+1, ...)
       pred_pushfwd = jax.lax.stop_gradient(
-          self.vmapped_pred_integrator(
-              x0, tspan[:-1], dict(params=params, **mutables)
+          self.pred_integrator(
+              x0, batch["tspan"][:-1], dict(params=params, **mutables)
           )
       )
       if self.conf.num_lookback_steps > 1:
+        # Expected shape: (batch_size, num_lookback_steps, ...)
         pred_pushfwd = pred_pushfwd[:, -self.conf.num_lookback_steps :, ...]
       else:
+        # Expected shape: (batch_size, ...) - no temporal dim
         pred_pushfwd = pred_pushfwd[:, -1, ...]
       # Pushforward for final step
-      pred = self.vmapped_pred_integrator(
-          pred_pushfwd, tspan[-2:], dict(params=params, **mutables)
+      # Expected shape: (batch_size, ...) - no temporal dim
+      pred = self.pred_integrator(
+          pred_pushfwd, batch["tspan"][-2:], dict(params=params, **mutables)
       )[:, -1, ...]

     # Computing losses.
@@ -132,7 +137,8 @@ def loss_fn(
       l2 = jnp.mean(jnp.square(pred - true[:, -1, ...]))
     else:
       # Regular unrolling without stop-gradient
-      pred = self.vmapped_pred_integrator(
+      # Expected shape: (bsz, num_rollout_steps, ...)
+      pred = self.pred_integrator(
           x0, tspan, dict(params=params, **mutables)
       )[:, self.conf.num_lookback_steps :, ...]
       measure_dist = jnp.mean(
@@ -179,15 +185,15 @@ def eval_fn(
       rng: jax.random.KeyArray,
       **kwargs,
   ) -> models.ArrayDict:
     tspan = batch["tspan"].reshape((-1,))
-
-    pred_trajs = self.vmapped_pred_integrator(batch["ic"], tspan, variables)[
+    # Keep extra step for plot fns
+    pred_trajs = self.pred_integrator(batch["ic"], tspan, variables)[
         :, self.conf.num_lookback_steps - 1 :, ...
-    ]  # Keep extra step for plot fns
+    ]
     trajs = batch["true"]

     if (
-        self.conf.normalize_stats["mean"] is not None
+        self.conf.normalize_stats is not None
+        and self.conf.normalize_stats["mean"] is not None
         and self.conf.normalize_stats["std"] is not None
     ):
       trajs *= self.conf.normalize_stats["std"]
@@ -197,8 +203,7 @@ def eval_fn(

     # TODO(lzepedanunez): this only computes the local sinkhorn distance.
     sd = measure_distances.sinkhorn_div(
-        pred_trajs[:, 1:, ...],
-        trajs[:, 1:, ...]
+        pred_trajs[:, -1, ...], trajs[:, -1, ...]
     )
     dt = tspan[1] - tspan[0]
     return dict(
@@ -207,6 +212,7 @@ def eval_fn(
         trajs=trajs,
         pred_trajs=pred_trajs,
     )
+  # pytype: enable=bad-return-type
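Note: the pushforward branch of `loss_fn` above rolls out all but the last step under `stop_gradient` and backpropagates only through the final step. A self-contained sketch of the trick, with a generic `step` function standing in for the dispatched integrator (names here are illustrative, not the project API):

    import jax
    import jax.numpy as jnp

    def pushforward_loss(params, step, x0, tspan, true_final):
      # Roll out to the second-to-last time without tracking gradients.
      x = x0
      for t0, t1 in zip(tspan[:-2], tspan[1:-1]):
        x = step(params, x, t0, t1 - t0)
      x = jax.lax.stop_gradient(x)
      # Only the final step contributes to the gradient.
      pred = step(params, x, tspan[-2], tspan[-1] - tspan[-2])
      return jnp.mean(jnp.square(pred - true_final))

This keeps long rollouts from inflating the backward pass while still training the model on states reached by its own predictions.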
@@ -274,17 +280,26 @@ def _preprocess_train_batch(
     # x0, except when num_lookback_steps > 1, where true[:, 0] corresponds to
     # the last time step in x0).
     if self.conf.num_lookback_steps > 1:
+      # Expected shape:
+      # - Lorenz: (bsz, num_lookback_steps, 3)
+      # - KS: (bsz, num_lookback_steps, spatial_dim, 1)
+      # - NS: (bsz, num_lookback_steps, spatial_dim, spatial_dim, 1)
       x0 = batch_data["u"][:, : self.conf.num_lookback_steps, ...]
+      # Expected shape: (bsz, num_rollout_steps + 1, ...),
+      # where ... is same as just above.
       true = batch_data["u"][
           :,
-          self.conf.num_lookback_steps
-          - 1 : num_time_steps
-          + self.conf.num_lookback_steps
-          - 1,  # pylint: disable=line-too-long
+          self.conf.num_lookback_steps - 1 : num_time_steps + self.conf.num_lookback_steps - 1,  # pylint: disable=line-too-long
           ...,
       ]
     else:
+      # Expected shape (no temporal dim):
+      # - Lorenz: (bsz, 3)
+      # - KS: (bsz, spatial_dim, 1)
+      # - NS: (bsz, spatial_dim, spatial_dim, 1)
       x0 = batch_data["u"][:, 0, ...]
+      # Expected shape: (bsz, num_rollout_steps + 1, ...),
+      # where ... is same as just above.
       true = batch_data["u"][:, :num_time_steps, ...]
     return dict(
         x0=x0,
@@ -319,6 +334,12 @@ def preprocess_train_batch(
       num_time_steps = jax.random.randint(
           rng, (1,), minval=2, maxval=num_time_steps + 1
       )[0]
+    # pytype: disable=attribute-error
+    assert num_time_steps <= batch_data["u"].shape[1], (
+        f"Not enough time steps in data ({batch_data['u'].shape[1]}) for"
+        f" desired steps ({num_time_steps})."
+    )
+    # pytype: enable=attribute-error
     return self._preprocess_train_batch(batch_data, num_time_steps)

   def preprocess_eval_batch(