Skip to content

Commit

Permalink
Code update
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 561981580
  • Loading branch information
zhong1wan authored and The swirl_dynamics Authors committed Sep 1, 2023
1 parent 74ef759 commit 1fe9d36
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 2 deletions.
16 changes: 14 additions & 2 deletions swirl_dynamics/lib/solvers/ode.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

"""Solvers for ordinary differential equations (ODEs)."""

import dataclasses
from typing import Any, Protocol

import flax
Expand Down Expand Up @@ -45,8 +46,16 @@ def __call__(
...


@dataclasses.dataclass
class ScanOdeSolver:
"""ODE solver based on `jax.lax.scan`."""
"""ODE solver based on `jax.lax.scan`.
Attributes:
time_axis_pos: move the time axis to the specified position in the output
tensor (by default it is at the 0th position).
"""

time_axis_pos: int = 0

def step(
self, func: OdeDynamics, x0: Array, t0: Array, dt: Array, params: PyTree
Expand All @@ -68,7 +77,10 @@ def scan_fun(
return (x_next, t_next), x_next

_, out = jax.lax.scan(scan_fun, (x0, tspan[0]), tspan[1:])
return jnp.concatenate([x0[None], out], axis=0)
out = jnp.concatenate([x0[None], out], axis=0)
if self.time_axis_pos:
out = jnp.moveaxis(out, 0, self.time_axis_pos)
return out


class ExplicitEuler(ScanOdeSolver):
Expand Down
14 changes: 14 additions & 0 deletions swirl_dynamics/lib/solvers/ode_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,20 @@ def test_output_shape_and_value(self, solver, backward):
self.assertEqual(out.shape, (num_steps, x_dim))
np.testing.assert_allclose(out[-1], np.ones((x_dim,)) * tspan[-1])

def test_move_time_axis_pos(self):
dt = 0.1
num_steps = 10
x_dim = 5
batch_sz = 6
tspan = jnp.arange(num_steps) * dt
out = ode.ExplicitEuler(time_axis_pos=1)(
dummy_ode_dynamics, jnp.zeros((batch_sz, x_dim)), tspan, {}
)
self.assertEqual(out.shape, (batch_sz, num_steps, x_dim))
np.testing.assert_allclose(
out[:, -1], np.ones((batch_sz, x_dim)) * tspan[-1]
)

@parameterized.parameters((np.arange(10) * -1,), (np.zeros(10),))
def test_dopri45_backward_error(self, tspan):
tspan = jnp.asarray(tspan)
Expand Down

0 comments on commit 1fe9d36

Please sign in to comment.