Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565226900
  • Loading branch information
The swirl_dynamics Authors committed Sep 14, 2023
1 parent 2b5cd84 commit e3040a6
Show file tree
Hide file tree
Showing 11 changed files with 465 additions and 209 deletions.
66 changes: 34 additions & 32 deletions swirl_dynamics/lib/solvers/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@
# limitations under the License.

"""Solvers for ordinary differential equations (ODEs)."""

from typing import Any, Protocol

import flax
Expand Down Expand Up @@ -60,7 +59,7 @@ def step(
self, func: OdeDynamics, x0: Array, t0: Array, dt: Array, params: PyTree
) -> Array:
"""Advances the current state one step forward in time."""
raise NotImplementedError
raise NotImplementedError("Scan solver must implement `step` method.")

def __call__(
self, func: OdeDynamics, x0: Array, tspan: Array, params: PyTree
Expand Down Expand Up @@ -150,45 +149,43 @@ def __call__(
)


@flax.struct.dataclass
class MultiStepScanOdeSolver:
"""ODE solver based on `jax.lax.scan` that uses more than one time step.
Rather than x_{n+1} = f(x_n, t_n), we have
x_{n+1} = f(x_{n-k}, x_{n-k+1}, ..., x_{n-1}, x_n, t_{n-k}, ..., t_{n-1}, t_n)
for some 'num_lookback_steps' window k.
Attributes:
time_axis_pos: move the time axis to the specified position in the output
tensor (by default it is at the 0th position). This attribute is used to
both indicate the temporal axis position of the input and where to place
the temporal axis of the output.
"""

@staticmethod
def stack_timesteps_along_channel_dim(x: Array) -> Array:
time_axis_pos: int = 0

def stack_timesteps_along_channel_dim(self, x: Array) -> Array:
"""Helper method to package batches for multi-step solvers.
Args:
x: Array of shape: (num_lookback_steps, state_dims, channels), where
state_dims can have ndim >= 1
x: Array containing axes for batch_size (potentially),
lookback_steps (e.g., temporal axis), spatial_dims, and channels, where
spatial_dims can have ndim >= 1
Returns:
Array of shape (state_dims, num_lookback_steps*channels)
Array where each time step in the temporal dim is concatenated along
the channel axis
"""
orig_shape = x.shape
num_lookback_steps = orig_shape[0]
stacked_shape = list(orig_shape)
# All the previous timesteps are collapsed into one
stacked_shape[0] = 1
# Concatenate steps along channel dim
stacked_shape[-1] *= num_lookback_steps
# For state_dims with ndim > 1, e.g. 2D grids, flatten state dims:
flattened_state_size = np.prod(list(x.shape[1:-1]))
x_flattened_state_shape = (
(x.shape[0],) + (flattened_state_size,) + (x.shape[-1],)
)
x = x.reshape(x_flattened_state_shape)
return x.swapaxes(0, 1).reshape(stacked_shape).squeeze(axis=0)
x = jnp.moveaxis(x, self.time_axis_pos, -2)
return jnp.reshape(x, x.shape[:-2] + (-1,))

def step(
self, func: OdeDynamics, x0: Array, t0: Array, dt: Array, params: PyTree
) -> Array:
"""Advances the current state one step forward in time."""
raise NotImplementedError
raise NotImplementedError("Scan solver must implement `step` method.")

def __call__(
self,
Expand All @@ -202,21 +199,26 @@ def __call__(
def scan_fun(
state: tuple[Array, Array], t_next: Array
) -> tuple[tuple[Array, Array], Array]:
# x0 assumed to have shape: (lookback, state_dims, channels)
# Expected dimension for x0 is either:
# - (t, ...), if time_axis_pos == 0
# - (batch_size, ..., t, ...), if time_axis_pos > 0
x0, t0 = state
x0_stack = self.stack_timesteps_along_channel_dim(x0)[None, ...]
# input to func has shape: (state_dims, channels*lookback)
x0_stack = self.stack_timesteps_along_channel_dim(x0)
dt = t_next - t0
# return item (x_next) has shape: state_dims x channels
x_next = self.step(func, x0_stack, t0, dt, params)
# carry item has same shape as x0, where we first shift over the original
# input and append the new predicted state along the time dimension
x_carry = jnp.concatenate([x0[1:, ...], x_next], axis=0)
return (x_carry, t_next), x_next.squeeze(axis=0)
x_next = jnp.expand_dims(x_next, axis=self.time_axis_pos)
x_prev = jnp.take(
x0, np.arange(1, x0.shape[self.time_axis_pos]),
axis=self.time_axis_pos
)
x_carry = jnp.concatenate([x_prev, x_next], axis=self.time_axis_pos)
return (x_carry, t_next), x_next.squeeze(axis=self.time_axis_pos)

_, out = jax.lax.scan(scan_fun, (x0, tspan[0]), tspan[1:])
# output of scan has shape: len(tspan) - 1 x state_dims x channels
return jnp.concatenate([x0, out], axis=0)
return jnp.concatenate(
[x0, jnp.moveaxis(out, 0, self.time_axis_pos)],
axis=self.time_axis_pos
)


class MultiStepDirect(MultiStepScanOdeSolver):
Expand Down
146 changes: 131 additions & 15 deletions swirl_dynamics/lib/solvers/ode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import functools

from absl.testing import absltest
from absl.testing import parameterized
import jax
Expand All @@ -25,6 +27,12 @@ def dummy_ode_dynamics(x, t, params):
return jnp.ones_like(x)


def dummy_ode_dynamics_return_original_num_channel(x, t, params, num_channels):
  """Dummy dynamics returning all ones with `num_channels` output channels.

  Only the shape and dtype of `x` are used: the output keeps every leading
  axis of `x` and replaces the last (channel) axis with `num_channels`
  entries.

  Args:
    x: Input array; only its shape and dtype are read.
    t: Unused.
    params: Unused.
    num_channels: Size of the last axis of the returned array.

  Returns:
    An array of ones with shape `x.shape[:-1] + (num_channels,)` and the same
    dtype as `x`.
  """
  del t, params
  # Follow the input dtype (as `jnp.ones_like` does in `dummy_ode_dynamics`) so
  # the output can be concatenated with the input without a dtype promotion.
  return jnp.ones(x.shape[:-1] + (num_channels,), dtype=x.dtype)


class OdeSolversTest(parameterized.TestCase):

@parameterized.parameters(
Expand All @@ -46,18 +54,55 @@ def test_output_shape_and_value(self, solver, backward):
self.assertEqual(out.shape, (num_steps, x_dim))
np.testing.assert_allclose(out[-1], np.ones((x_dim,)) * tspan[-1])

def test_move_time_axis_pos(self):
def test_output_shape_and_value_one_step_direct(self):
dt = 0.1
num_steps = 10
x_dim = 5
tspan = jnp.arange(num_steps) * dt
solver = ode.OneStepDirect()
out = solver(dummy_ode_dynamics, jnp.zeros((x_dim,)), tspan, {})
self.assertEqual(out.shape, (num_steps, x_dim))
np.testing.assert_array_equal(out[-1], np.ones((x_dim,)))

@parameterized.parameters(
{"time_axis_pos": 1, "x_dim": (5,)},
{"time_axis_pos": 2, "x_dim": (5, 5)},
)
def test_move_time_axis_pos(self, time_axis_pos, x_dim):
dt = 0.1
num_steps = 10
batch_sz = 6
input_shape = (batch_sz,) + x_dim
tspan = jnp.arange(num_steps) * dt
out = ode.ExplicitEuler(time_axis_pos=1)(
dummy_ode_dynamics, jnp.zeros((batch_sz, x_dim)), tspan, {}
solver = ode.ExplicitEuler(time_axis_pos=time_axis_pos)
expected_output_shape = (
input_shape[:time_axis_pos] + (num_steps,) + input_shape[time_axis_pos:]
)
self.assertEqual(out.shape, (batch_sz, num_steps, x_dim))
out = solver(dummy_ode_dynamics, jnp.zeros(input_shape), tspan, {})
self.assertEqual(out.shape, expected_output_shape)
np.testing.assert_allclose(
out[:, -1], np.ones((batch_sz, x_dim)) * tspan[-1]
jnp.moveaxis(out, time_axis_pos, 1)[:, -1],
np.ones(input_shape) * tspan[-1],
)

@parameterized.parameters(
{"time_axis_pos": 1, "x_dim": (5,)},
{"time_axis_pos": 2, "x_dim": (5, 5)},
)
def test_move_time_axis_pos_one_step_direct(self, time_axis_pos, x_dim):
dt = 0.1
num_steps = 10
batch_sz = 6
input_shape = (batch_sz,) + x_dim
tspan = jnp.arange(num_steps) * dt
solver = ode.OneStepDirect(time_axis_pos=time_axis_pos)
out = solver(dummy_ode_dynamics, jnp.zeros(input_shape), tspan, {})
expected_output_shape = (
input_shape[:time_axis_pos] + (num_steps,) + input_shape[time_axis_pos:]
)
self.assertEqual(out.shape, expected_output_shape)
np.testing.assert_array_equal(
jnp.moveaxis(out, time_axis_pos, 1)[:, -1], np.ones(input_shape)
)

@parameterized.parameters((np.arange(10) * -1,), (np.zeros(10),))
Expand All @@ -70,37 +115,108 @@ def test_dopri45_backward_error(self, tspan):
class MultiStepOdeSolversTest(parameterized.TestCase):

@parameterized.product(
time_axis_pos=(0, 1, 2),
batch_size=(1, 8),
state_dim=((512,), (64, 64), (32, 32, 32)),
channels=(1, 2, 3),
num_lookback_steps=(2, 4, 8),
)
def test_stacked_output_shape_and_value(
self,
time_axis_pos,
batch_size,
state_dim,
channels,
num_lookback_steps,
):
# Setup initial shape with time in axis 0
input_shape = [num_lookback_steps, batch_size]
input_shape.extend(state_dim)
input_shape.append(channels)
# Re-order shape to match time_axis_pos (for time_axis_pos=0, this is a no-op)
input_shape[0] = input_shape[time_axis_pos]
input_shape[time_axis_pos] = num_lookback_steps

rng = jax.random.PRNGKey(0)
input_shape = (num_lookback_steps,) + state_dim + (channels,)
input_state = jax.random.normal(rng, input_shape)
input_state_stacked = (
ode.MultiStepScanOdeSolver.stack_timesteps_along_channel_dim(
input_state
)
input_state_stacked = ode.MultiStepScanOdeSolver(
time_axis_pos=time_axis_pos
).stack_timesteps_along_channel_dim(input_state)
expected_output_shape = (
input_shape[:time_axis_pos]
+ input_shape[time_axis_pos + 1 : -1]
+ [channels * num_lookback_steps]
)
# Check that expected shapes match
self.assertEqual(
input_state_stacked.shape, state_dim + (channels * num_lookback_steps,)
)
# Check that timesteps were correctly concatenated along channel dim
self.assertEqual(input_state_stacked.shape, tuple(expected_output_shape))
# Check that timesteps were correctly concatenated along channel dim
for w in range(num_lookback_steps):
c_start = channels * w
c_end = channels * (w + 1)
np.testing.assert_array_equal(
input_state[w, ...],
jnp.moveaxis(input_state, time_axis_pos, 0)[w, ...],
input_state_stacked[..., c_start:c_end],
)

@parameterized.product(
time_axis_pos=(0, 1, 2),
batch_size=(1, 8),
state_dim=((512,), (64, 64), (32, 32, 32)),
channels=(1, 2, 3),
num_lookback_steps=(2, 4, 8),
)
def test_output_shape_and_value_multi_step_direct(
self,
time_axis_pos,
batch_size,
state_dim,
channels,
num_lookback_steps,
):
# Setup initial shape with time in axis 0
input_shape = [num_lookback_steps, batch_size]
input_shape.extend(state_dim)
input_shape.append(channels)
# Re-order shape to match time_axis_pos (for time_axis_pos=0, this is a no-op)
input_shape[0] = input_shape[time_axis_pos]
input_shape[time_axis_pos] = num_lookback_steps
dt = 0.1
num_steps = 10
tspan = jnp.arange(num_steps) * dt
solver = ode.MultiStepDirect(time_axis_pos=time_axis_pos)
out = solver(
functools.partial(
dummy_ode_dynamics_return_original_num_channel,
num_channels=channels,
),
jnp.zeros(input_shape),
tspan,
{},
)
expected_output_shape = (
input_shape[:time_axis_pos]
+ [num_steps]
+ input_shape[time_axis_pos + 1 :]
)
out = jnp.take(
out,
np.arange(num_lookback_steps - 1, out.shape[time_axis_pos]),
axis=time_axis_pos,
)
self.assertEqual(out.shape, tuple(expected_output_shape))
np.testing.assert_array_equal(
jnp.take(
out,
np.arange(1, out.shape[time_axis_pos]),
axis=time_axis_pos,
),
jnp.take(
np.ones(expected_output_shape),
np.arange(1, out.shape[time_axis_pos]),
axis=time_axis_pos,
),
)


if __name__ == "__main__":
absltest.main()
22 changes: 16 additions & 6 deletions swirl_dynamics/projects/ergodic/choices.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,17 +59,19 @@ class Integrator(enum.Enum):

def dispatch(
self,
) -> type[ode.ScanOdeSolver] | type[ode.MultiStepScanOdeSolver]:
) -> ode.ScanOdeSolver | ode.MultiStepScanOdeSolver:
"""Dispatch integrator.
Returns:
ScanOdeSolver | MultiStepScanOdeSolver
"""
# TODO(yairschiff): Profile if the moveaxis call required here introduces a
# bottleneck
return {
"ExplicitEuler": ode.ExplicitEuler,
"RungeKutta4": ode.RungeKutta4,
"OneStepDirect": ode.OneStepDirect,
"MultiStepDirect": ode.MultiStepDirect,
"ExplicitEuler": ode.ExplicitEuler(time_axis_pos=1),
"RungeKutta4": ode.RungeKutta4(time_axis_pos=1),
"OneStepDirect": ode.OneStepDirect(time_axis_pos=1),
"MultiStepDirect": ode.MultiStepDirect(time_axis_pos=1),
}[self.value]


Expand Down Expand Up @@ -109,7 +111,8 @@ def dispatch(
class Model(enum.Enum):
"""Model choices."""

FNO = "FNO"
FNO = "Fno"
FNO_2D = "Fno2d"
MLP = "MLP"
PERIODIC_CONV_NET_MODEL = "PeriodicConvNetModel"

Expand Down Expand Up @@ -148,4 +151,11 @@ def dispatch(self, conf: ml_collections.ConfigDict) -> nn.Module:
fft_norm=conf.fft_norm,
separable=conf.separable,
)
if self.value == Model.FNO_2D.value:
return fno.Fno2d(
out_channels=conf.out_channels,
num_modes=conf.num_modes,
width=conf.width,
fft_norm=conf.fft_norm
)
raise ValueError()
Loading

0 comments on commit e3040a6

Please sign in to comment.