From 6263541e54093f0025879c4abfbf62502a4efbea Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 29 Oct 2024 09:25:52 +0000 Subject: [PATCH 1/8] Convert Solver.flatten to a classmethod --- bt_ocean/model.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index f965ed3..9d1b811 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -791,14 +791,14 @@ def flatten(self): """ return ((dict(self.fields), dict(self.dealias_fields), self.n), - (type(self), self.parameters, self.poisson_solver)) + (self.parameters, self.poisson_solver)) - @staticmethod - def unflatten(aux_data, children): + @classmethod + def unflatten(cls, aux_data, children): """Unpack a JAX flattened representation. """ - cls, parameters, poisson_solver = aux_data + parameters, poisson_solver = aux_data fields, dealias_fields, n = children solver = cls(parameters) @@ -951,8 +951,8 @@ def flatten(self): children, aux_data = super().flatten() return children, aux_data + (self.modified_helmholtz_solver,) - @staticmethod - def unflatten(aux_data, children): - solver = Solver.unflatten(aux_data[:-1], children) + @classmethod + def unflatten(cls, aux_data, children): + solver = super().unflatten(aux_data[:-1], children) solver.modified_helmholtz_solver = aux_data[-1] return solver From f7addf057ed2847f89e395cfa09cb924fdfc8181 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 29 Oct 2024 09:34:42 +0000 Subject: [PATCH 2/8] Change Solver.steady_state_solve API --- bt_ocean/model.py | 40 +++++++++++------------ docs/source/examples/2_steady_state.ipynb | 2 +- 2 files changed, 21 insertions(+), 21 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 9d1b811..9be0000 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -593,7 +593,7 @@ def ke_spectrum(self, N_x, N_y): return (0.5 * self.grid.L_x * self.grid.L_y * (k ** 2 + l ** 2) * (dst(dst(psi, axis=1), axis=0) ** 2)) - def steady_state_solve(self, m=(), update=lambda model, *m: None, *, tol, max_it=10000, _min_n=0): + def steady_state_solve(self, update=lambda model, *args: None, *args, tol, max_it=10000, _min_n=0): r"""Timestep to steady-state. Uses timestepping to define a fixed-point iteration, and applies @@ -615,12 +615,12 @@ def steady_state_solve(self, m=(), update=lambda model, *m: None, *, tol, max_it Parameters ---------- - m : Sequence[:class:`jax.Array`, ...] - Additional control variables. update : callable A callable accepting a :class:`.Solver` as the zeroth argument and the elements of `m` as remaining positional arguments, and which updates the values of control variables. + args : tuple + Passed to `update`. tol : Real Tolerance. The system is timestepped until @@ -642,18 +642,18 @@ def steady_state_solve(self, m=(), update=lambda model, *m: None, *, tol, max_it """ @jax.jit - def forward_step(data, m): + def forward_step(data, args): _, model, it = data zeta_0 = model.fields["zeta"] - update(model, *m) + update(model, *args) model.step() return (zeta_0, model, it + 1) @jax.custom_vjp - def forward(model, m): + def forward(model, args): while model.n < _min_n: - zeta_0, model, _ = forward_step((None, model, 0), m) - zeta_0, model, _ = forward_step((None, model, 0), m) + zeta_0, model, _ = forward_step((None, model, 0), args) + zeta_0, model, _ = forward_step((None, model, 0), args) def non_convergence(data): zeta_0, model, it = data @@ -663,20 +663,20 @@ def non_convergence(data): abs(zeta_1 - zeta_0).max() > tol * abs(zeta_1).max()) _, model, it = jax.lax.while_loop( - non_convergence, partial(forward_step, m=m), + non_convergence, partial(forward_step, args=args), (zeta_0, model, 1)) return model, it - def forward_fwd(model, m): - model, it = forward(model, m) - return (model, it), (model, m) + def forward_fwd(model, args): + model, it = forward(model, args) + return (model, it), (model, args) def forward_bwd(res, zeta): - model, m = res + model, args = res zeta_model, _ = zeta _, vjp_step = jax.vjp( - lambda model: forward_step((None, model, 0), m)[1], model) + lambda model: forward_step((None, model, 0), args)[1], model) @jax.jit def adj_step(data, zeta_model): @@ -709,14 +709,14 @@ def non_convergence(data): return lam_model lam_model = adjoint(zeta_model) - _, vjp = jax.vjp(lambda m: forward_step((None, model, 0), m)[1], m) - lam_m, = vjp(lam_model) + _, vjp = jax.vjp(lambda args: forward_step((None, model, 0), args)[1], args) + lam_args, = vjp(lam_model) - return lam_model, lam_m + return lam_model, lam_args forward.defvjp(forward_fwd, forward_bwd) - model, it = forward(self, m) + model, it = forward(self, args) self.update(model) if it > max_it: raise SteadyStateMaximumIterationsError("Maximum number of iterations exceeded") @@ -939,8 +939,8 @@ def ke(self): v = self.dealias_fields["v"] return 0.5 * jnp.tensordot((u * u + v * v), self.dealias_grid.W) - def steady_state_solve(self, m=(), update=lambda model, *m: None, *, tol, max_it=10000): - return super().steady_state_solve(m=m, update=update, tol=tol, _min_n=1, max_it=max_it) + def steady_state_solve(self, m=(), update=lambda model, *args: None, *args, tol, max_it=10000): + return super().steady_state_solve(*args, update=update, tol=tol, _min_n=1, max_it=max_it) def new(self): solver = super().new() diff --git a/docs/source/examples/2_steady_state.ipynb b/docs/source/examples/2_steady_state.ipynb index ebb5790..e3f584e 100644 --- a/docs/source/examples/2_steady_state.ipynb +++ b/docs/source/examples/2_steady_state.ipynb @@ -180,7 +180,7 @@ " model.fields[\"Q\"] = Q - (r - model.r) * model.fields[\"zeta\"]\n", "\n", " model = CNAB2Solver(parameters)\n", - " model.steady_state_solve(m=(r,), update=update, tol=1.0e-6)\n", + " model.steady_state_solve(r, update=update, tol=1.0e-6)\n", " return D * model.fields[\"psi\"].max() / 1.0e6\n", "\n", "\n", From 3b3d24e5414ebd180058bff8947d292b6cd8a419 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 29 Oct 2024 09:42:38 +0000 Subject: [PATCH 3/8] Keras Solver serialization --- bt_ocean/model.py | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 9be0000..72546de 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -4,6 +4,7 @@ import jax import jax.numpy as jnp +import keras import numpy as np from abc import ABC, abstractmethod @@ -424,6 +425,7 @@ def unflatten(aux_data, children): jax.tree_util.register_pytree_node(cls, flatten, unflatten) Solver._registry[cls.__name__] = cls + keras.saving.register_keras_serializable(package=f"_bt_ocean__{cls.__name__}")(cls) @cached_property def beta(self) -> jax.Array: @@ -810,6 +812,23 @@ def unflatten(cls, aux_data, children): return solver + def get_config(self): + return {"type": type(self).__name__, + "parameters": dict(self.parameters), + "fields": dict(self.fields), + "dealias_fields": dict(self.dealias_fields), + "n": self.n} + + @classmethod + def from_config(cls, config): + config = {key: keras.saving.deserialize_keras_object(value) for key, value in config.items()} + cls = Solver._registry[config["type"]] + solver = cls(config["parameters"]) + solver.fields.update(config["fields"]) + solver.dealias_fields.update(config["dealias_fields"]) + solver.n = config["n"] + return solver + def read_solver(h, path="solver"): """Read solver from a :class:`zarr.hierarchy.Group`. From 4909b55d9ea814b2a7ac0a6da9065e81fe91bbf1 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 29 Oct 2024 09:53:25 +0000 Subject: [PATCH 4/8] Solver.steady_state_solve fix --- bt_ocean/model.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 72546de..2773de3 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -595,7 +595,7 @@ def ke_spectrum(self, N_x, N_y): return (0.5 * self.grid.L_x * self.grid.L_y * (k ** 2 + l ** 2) * (dst(dst(psi, axis=1), axis=0) ** 2)) - def steady_state_solve(self, update=lambda model, *args: None, *args, tol, max_it=10000, _min_n=0): + def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000, _min_n=0): r"""Timestep to steady-state. Uses timestepping to define a fixed-point iteration, and applies @@ -958,8 +958,8 @@ def ke(self): v = self.dealias_fields["v"] return 0.5 * jnp.tensordot((u * u + v * v), self.dealias_grid.W) - def steady_state_solve(self, m=(), update=lambda model, *args: None, *args, tol, max_it=10000): - return super().steady_state_solve(*args, update=update, tol=tol, _min_n=1, max_it=max_it) + def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000, _min_n=0): + return super().steady_state_solve(*args, update=update, tol=tol, max_it=max_it, _min_n=1) def new(self): solver = super().new() From 906ad1745eb60af34582a1750c6b0305cacefa60 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 29 Oct 2024 10:13:06 +0000 Subject: [PATCH 5/8] Solver -> cls --- bt_ocean/model.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 2773de3..dda44bf 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -424,7 +424,7 @@ def unflatten(aux_data, children): jax.tree_util.register_pytree_node(cls, flatten, unflatten) - Solver._registry[cls.__name__] = cls + cls._registry[cls.__name__] = cls keras.saving.register_keras_serializable(package=f"_bt_ocean__{cls.__name__}")(cls) @cached_property @@ -822,7 +822,7 @@ def get_config(self): @classmethod def from_config(cls, config): config = {key: keras.saving.deserialize_keras_object(value) for key, value in config.items()} - cls = Solver._registry[config["type"]] + cls = cls._registry[config["type"]] solver = cls(config["parameters"]) solver.fields.update(config["fields"]) solver.dealias_fields.update(config["dealias_fields"]) From 09f64dbd4d5444bdf0d1ede05237c58dec585418 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 29 Oct 2024 10:19:45 +0000 Subject: [PATCH 6/8] Convert read methods into class methods --- bt_ocean/model.py | 167 +++++++++++++++++++++++----------------------- tests/test_io.py | 9 ++- 2 files changed, 86 insertions(+), 90 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index dda44bf..7314817 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -19,16 +19,13 @@ [ "Parameters", "required", - "read_parameters", "Fields", - "read_fields", "SteadyStateMaximumIterationsError", "NanEncounteredError", "Solver", - "CNAB2Solver", - "read_solver" + "CNAB2Solver" ] @@ -94,26 +91,26 @@ def write(self, h, path="parameters"): g.attrs.update(self.items()) return g + @classmethod + def read(cls, h, path="parameters"): + """Read parameters from a :class:`zarr.hierarchy.Group`. -def read_parameters(h, path="parameters"): - """Read parameters from a :class:`zarr.hierarchy.Group`. - - Parameters - ---------- + Parameters + ---------- - h : :class:`zarr.hierarchy.Group` - Parent group. - path : str - Group path. + h : :class:`zarr.hierarchy.Group` + Parent group. + path : str + Group path. - Returns - ------- + Returns + ------- - :class:`.Parameters` - The parameters. - """ + :class:`.Parameters` + The parameters. + """ - return Parameters(h[path].attrs) + return cls(h[path].attrs) class Fields(Mapping): @@ -255,50 +252,50 @@ def update(self, d): for key, value in d.items(): self[key] = value + @classmethod + def read(cls, h, path="fields", *, grid=None): + """Read fields from a :class:`zarr.hierarchy.Group`. -def read_fields(h, path="fields", *, grid=None): - """Read fields from a :class:`zarr.hierarchy.Group`. - - Parameters - ---------- + Parameters + ---------- - h : :class:`zarr.hierarchy.Group` - Parent group. - path : str - Group path. - grid : :class:`.Grid` - The 2D Chebyshev grid. + h : :class:`zarr.hierarchy.Group` + Parent group. + path : str + Group path. + grid : :class:`.Grid` + The 2D Chebyshev grid. - Returns - ------- + Returns + ------- - :class:`.Fields` - The fields. - """ + :class:`.Fields` + The fields. + """ - g = h[path] - del h + g = h[path] + del h - L_x = g.attrs["L_x"] - L_y = g.attrs["L_y"] - N_x = g.attrs["N_x"] - N_y = g.attrs["N_y"] - idtype = jnp.dtype(g.attrs["idtype"]).type - fdtype = jnp.dtype(g.attrs["fdtype"]).type - if grid is None: - grid = Grid(L_x, L_y, N_x, N_y, idtype=idtype, fdtype=fdtype) - if L_x != grid.L_x or L_y != grid.L_y: - raise ValueError("Invalid dimension(s)") - if N_x != grid.N_x or N_y != grid.N_y: - raise ValueError("Invalid degree(s)") - if idtype != grid.idtype or fdtype != grid.fdtype: - raise ValueError("Invalid dtype(s)") + L_x = g.attrs["L_x"] + L_y = g.attrs["L_y"] + N_x = g.attrs["N_x"] + N_y = g.attrs["N_y"] + idtype = jnp.dtype(g.attrs["idtype"]).type + fdtype = jnp.dtype(g.attrs["fdtype"]).type + if grid is None: + grid = Grid(L_x, L_y, N_x, N_y, idtype=idtype, fdtype=fdtype) + if L_x != grid.L_x or L_y != grid.L_y: + raise ValueError("Invalid dimension(s)") + if N_x != grid.N_x or N_y != grid.N_y: + raise ValueError("Invalid degree(s)") + if idtype != grid.idtype or fdtype != grid.fdtype: + raise ValueError("Invalid dtype(s)") - fields = Fields(grid, set(g)) - for key in g: - fields[key] = g[key][...] + fields = cls(grid, set(g)) + for key in g: + fields[key] = g[key][...] - return fields + return fields class SteadyStateMaximumIterationsError(Exception): @@ -754,6 +751,36 @@ def write(self, h, path="solver"): return g + @classmethod + def read(cls, h, path="solver"): + """Read solver from a :class:`zarr.hierarchy.Group`. + + Parameters + ---------- + + h : :class:`zarr.hierarchy.Group` + Parent group. + path : str + Group path. + + Returns + ------- + + :class:`.Solver` + The solver. + """ + + g = h[path] + del h + + cls = cls._registry[g.attrs["type"]] + solver = cls(Parameters.read(g, "parameters")) + solver.fields.update(Fields.read(g, "fields", grid=solver.grid)) + solver.dealias_fields.update(Fields.read(g, "dealias_fields", grid=solver.dealias_grid)) + solver.n = g.attrs["n"] + + return solver + def new(self): """Return a new :class:`.Solver` with the same configuration as this :class:`.Solver`. @@ -830,36 +857,6 @@ def from_config(cls, config): return solver -def read_solver(h, path="solver"): - """Read solver from a :class:`zarr.hierarchy.Group`. - - Parameters - ---------- - - h : :class:`zarr.hierarchy.Group` - Parent group. - path : str - Group path. - - Returns - ------- - - :class:`.Solver` - The solver. - """ - - g = h[path] - del h - - cls = Solver._registry[g.attrs["type"]] - solver = cls(read_parameters(g, "parameters")) - solver.fields.update(read_fields(g, "fields", grid=solver.grid)) - solver.dealias_fields.update(read_fields(g, "dealias_fields", grid=solver.dealias_grid)) - solver.n = g.attrs["n"] - - return solver - - class CNAB2Solver(Solver): """Chebyshev pseudospectral solver for the 2D barotropic vorticity equation on a beta-plane, using a CNAB2 time discretization. diff --git a/tests/test_io.py b/tests/test_io.py index 0f663e8..fd2b98d 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,6 +1,5 @@ from bt_ocean.grid import Grid -from bt_ocean.model import ( - CNAB2Solver, Fields, Parameters, read_fields, read_parameters, read_solver) +from bt_ocean.model import CNAB2Solver, Fields, Parameters, Solver from bt_ocean.parameters import parameters import jax.numpy as jnp @@ -25,7 +24,7 @@ def test_parameters_roundtrip(tmp_path): with zarr.open(filename, "w") as h: parameters.write(h) with zarr.open(filename, "r") as h: - input_parameters = read_parameters(h) + input_parameters = Parameters.read(h) assert set(input_parameters) == set(parameters) for key, value in parameters.items(): @@ -47,7 +46,7 @@ def test_fields_roundtrip(tmp_path): with zarr.open(filename, "w") as h: fields.write(h) with zarr.open(filename, "r") as h: - input_fields = read_fields(h) + input_fields = Fields.read(h) assert input_fields.grid.L_x == L_x assert input_fields.grid.L_y == L_y @@ -73,7 +72,7 @@ def test_solver_roundtrip(tmp_path): with zarr.open(filename, "w") as h: solver.write(h) with zarr.open(filename, "r") as h: - input_solver = read_solver(h) + input_solver = Solver.read(h) assert type(input_solver) is type(solver) assert input_solver.n == solver.n From e94b2457851074b39f601b5de9f802c7418bb3e1 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 29 Oct 2024 11:33:13 +0000 Subject: [PATCH 7/8] Dynamics layer serialization. solver -> model. --- bt_ocean/model.py | 64 +++++----- bt_ocean/network.py | 110 ++++++++++++------ .../source/examples/1_keras_integration.ipynb | 11 +- tests/test_io.py | 70 +++++------ tests/test_network.py | 76 +++++++++++- 5 files changed, 221 insertions(+), 110 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 7314817..1d5678b 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -413,8 +413,8 @@ def __init__(self, parameters, *, field_keys=None, dealias_field_keys=None, def __init_subclass__(cls, **kwargs): super().__init_subclass__(**kwargs) - def flatten(solver): - return solver.flatten() + def flatten(model): + return model.flatten() def unflatten(aux_data, children): return cls.unflatten(aux_data, children) @@ -774,12 +774,12 @@ def read(cls, h, path="solver"): del h cls = cls._registry[g.attrs["type"]] - solver = cls(Parameters.read(g, "parameters")) - solver.fields.update(Fields.read(g, "fields", grid=solver.grid)) - solver.dealias_fields.update(Fields.read(g, "dealias_fields", grid=solver.dealias_grid)) - solver.n = g.attrs["n"] + model = cls(Parameters.read(g, "parameters")) + model.fields.update(Fields.read(g, "fields", grid=model.grid)) + model.dealias_fields.update(Fields.read(g, "dealias_fields", grid=model.dealias_grid)) + model.n = g.attrs["n"] - return solver + return model def new(self): """Return a new :class:`.Solver` with the same configuration as this @@ -792,23 +792,23 @@ def new(self): The new :class:`.Solver`. """ - solver = type(self)(self.parameters) - solver.poisson_solver = self.poisson_solver - return solver + model = type(self)(self.parameters) + model.poisson_solver = self.poisson_solver + return model - def update(self, solver): + def update(self, model): """Update the state of this :class:`.Solver`. Parameters ---------- - solver : :class:`.Solver` + model : :class:`.Solver` Defines the new state of this :class:`.Solver`. """ - self.fields.update(solver.fields) - self.dealias_fields.update(solver.dealias_fields) - self.n = solver.n + self.fields.update(model.fields) + self.dealias_fields.update(model.dealias_fields) + self.n = model.n def flatten(self): """Return a JAX flattened representation. @@ -830,14 +830,14 @@ def unflatten(cls, aux_data, children): parameters, poisson_solver = aux_data fields, dealias_fields, n = children - solver = cls(parameters) - solver.poisson_solver = poisson_solver - solver.fields.update({key: value for key, value in fields.items() if type(value) is not object}) - solver.dealias_fields.update({key: value for key, value in dealias_fields.items() if type(value) is not object}) + model = cls(parameters) + model.poisson_solver = poisson_solver + model.fields.update({key: value for key, value in fields.items() if type(value) is not object}) + model.dealias_fields.update({key: value for key, value in dealias_fields.items() if type(value) is not object}) if type(n) is not object: - solver.n = n + model.n = n - return solver + return model def get_config(self): return {"type": type(self).__name__, @@ -850,11 +850,11 @@ def get_config(self): def from_config(cls, config): config = {key: keras.saving.deserialize_keras_object(value) for key, value in config.items()} cls = cls._registry[config["type"]] - solver = cls(config["parameters"]) - solver.fields.update(config["fields"]) - solver.dealias_fields.update(config["dealias_fields"]) - solver.n = config["n"] - return solver + model = cls(config["parameters"]) + model.fields.update(config["fields"]) + model.dealias_fields.update(config["dealias_fields"]) + model.n = config["n"] + return model class CNAB2Solver(Solver): @@ -959,9 +959,9 @@ def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_i return super().steady_state_solve(*args, update=update, tol=tol, max_it=max_it, _min_n=1) def new(self): - solver = super().new() - solver.modified_helmholtz_solver = self.modified_helmholtz_solver - return solver + model = super().new() + model.modified_helmholtz_solver = self.modified_helmholtz_solver + return model def flatten(self): children, aux_data = super().flatten() @@ -969,6 +969,6 @@ def flatten(self): @classmethod def unflatten(cls, aux_data, children): - solver = super().unflatten(aux_data[:-1], children) - solver.modified_helmholtz_solver = aux_data[-1] - return solver + model = super().unflatten(aux_data[:-1], children) + model.modified_helmholtz_solver = aux_data[-1] + return model diff --git a/bt_ocean/network.py b/bt_ocean/network.py index bd2415b..aa3bb68 100644 --- a/bt_ocean/network.py +++ b/bt_ocean/network.py @@ -19,7 +19,16 @@ ] -@keras.saving.register_keras_serializable(package="bt_ocean_scale") +def update_config(config, d): + config = dict(config) + for key, value in d.items(): + if key in config: + raise KeyError(f"key '{key}' already defined") + config[key] = value + return config + + +@keras.saving.register_keras_serializable(package="_bt_ocean__Scale") class Scale(keras.layers.Layer): """A layer which multiplies by a constant trainable weight. """ @@ -35,7 +44,7 @@ def build(self, input_shape): pass -@keras.saving.register_keras_serializable(package="bt_ocean_kronecker_product") +@keras.saving.register_keras_serializable(package="_bt_ocean__KroneckerProduct") class KroneckerProduct(keras.layers.Layer): """A layer where the weights matrix has Kronecker product structure. @@ -86,22 +95,6 @@ def call(self, inputs): outputs = self.__activation(outputs) return outputs - def get_config(self): - def update(config, d): - config = dict(config) - for key, value in d.items(): - if key in config: - raise KeyError(f"key '{key}' already defined") - config[key] = value - return config - - return update(super().get_config(), - {"shape_a": self.__shape_a, - "shape_b": self.__shape_b, - "activation": self.__activation_arg, - "symmetric": self.__symmetric, - "bias": self.__bias}) - def build(self, input_shape): if tuple(input_shape)[-2:] != self.__shape_a: raise ValueError("Invalid shape") @@ -186,23 +179,20 @@ def kronecker_product_network( return keras.models.Model(inputs=input_layer, outputs=output_layer) +@keras.saving.register_keras_serializable(package="_bt_ocean__Dynamics") class Dynamics(keras.layers.Layer): - """Defines a layer consisting of a dynamical core with a neural network - parameterized forcing. + """Defines a layer consisting of a dynamical solver. Parameters ---------- dynamics : :class:`.Solver` - The dynamical core. - Q_0 : :class:`jax.numpy.Array` - Wind forcing term in the vorticity equation. - Q_network - The right-hand-side forcing neural network. - Q_callback : callable - Passed `dynamics` and `Q_network`, and should return an additional term - to be added to the right-hand-side of the vorticity equation. Evaluated - before taking each timestep. + The dynamical solver. + update : callable + Passed `dynamics` and any arguments defined by `args`, and should + update the state of `dynamics`. Evaluated before taking each timestep. + args : tuple + Passed as remaining arguments to `update`. N : Integral The number of timesteps to take using the dynamical solver between each output. @@ -214,25 +204,51 @@ class Dynamics(keras.layers.Layer): Weight by which to scale each output. """ - def __init__(self, dynamics, Q_0, Q_network, Q_callback, N, *, n_output=1, - input_weight=1, output_weight=1): - super().__init__(dtype=dynamics.grid.fdtype) + _update_registry = {} + + def __init__(self, dynamics, update, *args, N=1, n_output=1, + input_weight=1, output_weight=1, **kwargs): + if "dtype" not in kwargs: + kwargs["dtype"] = dynamics.grid.fdtype + super().__init__(**kwargs) self.__dynamics = dynamics - self.__Q_0 = Q_0 - self.__Q_network = Q_network - self.__Q_callback = Q_callback + self.__update = update + self.__args = args self.__N = N self.__n_output = n_output self.__input_weight = input_weight self.__output_weight = output_weight + @classmethod + def register_update(cls, key): + """Decorator for registration of an `update` callable. Required for + :class:`.Dynamics` serialization. + + Parameters + ---------- + + key : str + Key to associated with the callable. + + Returns + ------- + + callable + """ + + def wrapper(fn): + fn._bt_ocean__update_key = key + cls._update_registry[key] = fn + return fn + return wrapper + def compute_output_shape(self, input_shape): return (input_shape[0], self.__n_output) + input_shape[1:] def call(self, inputs): @jax.checkpoint def step(_, dynamics): - dynamics.fields["Q"] = self.__Q_0 + self.__Q_callback(dynamics, self.__Q_network) + self.__update(dynamics, *self.__args) dynamics.step() return dynamics @@ -249,6 +265,28 @@ def compute_outputs(zeta): outputs = jax.vmap(compute_outputs)(inputs) return outputs + def get_config(self): + return update_config( + super().get_config(), + {"_bt_ocean__Dynamics_config": {"dynamics": self.__dynamics, + "update_key": self.__update._bt_ocean__update_key, + "args": self.__args, + "N": self.__N, + "n_output": self.__n_output, + "input_weight": self.__input_weight, + "output_weight": self.__output_weight}}) + + @classmethod + def from_config(cls, config): + sub_config = {key: keras.saving.deserialize_keras_object(value) for key, value in config.pop("_bt_ocean__Dynamics_config").items()} + return cls(sub_config["dynamics"], + cls._update_registry[sub_config["update_key"]], + *sub_config["args"], + N=sub_config["N"], + n_output=sub_config["n_output"], + input_weight=sub_config["input_weight"], + output_weight=sub_config["output_weight"], **config) + class OnlineDataset(keras.utils.PyDataset): """Online training data set. diff --git a/docs/source/examples/1_keras_integration.ipynb b/docs/source/examples/1_keras_integration.ipynb index d56347f..7cede91 100644 --- a/docs/source/examples/1_keras_integration.ipynb +++ b/docs/source/examples/1_keras_integration.ipynb @@ -72,7 +72,7 @@ "source": [ "## A simple Keras model\n", "\n", - "We now set up our Keras model. To do this we first define a `keras.models.Model` and callable, which together define a map from our state to a forcing term appearing on the right-hand-side of the barotropic vorticity equation." + "We now set up our Keras model. To do this we first define a `keras.models.Model` which defines a map from our state to a forcing term which we add on the right-hand-side of the barotropic vorticity equation, and a callable which uses this map." ] }, { @@ -102,8 +102,8 @@ "Q_weight = tau_0 * jnp.pi / (D * rho_0 * model.grid.L_y)\n", "\n", "\n", - "def Q_callback(dynamics, Q_network):\n", - " return Q_weight * Q_network(jnp.zeros(shape=(1, 0)))[0, :, :]" + "def update(dynamics, Q_network):\n", + " dynamics.fields[\"Q\"] = Q_weight * Q_network(jnp.zeros(shape=(1, 0)))[0, :, :]" ] }, { @@ -111,7 +111,7 @@ "id": "d23edf22-f916-41fc-8095-2bdd8d9be6c9", "metadata": {}, "source": [ - "We now set up a `Dynamics` layer. This is a custom Keras layer which represents the mapping from an initial condition, in terms of the initial vorticity fields, to dynamical trajectories, in terms of the vorticity fields evaluated at later times. The trajectories are computed by solving the barotropic vorticity equation, while being forced with an extra right-hand-side term defined by the given `keras.models.Model`. We indicate that the `Dynamics` layer should use a known right-hand-side forcing term (the wind stress term, defined by the `Q_0` argument) of zero – meaning that in this example we will seek `Q_network` weights which reconstruct this term.\n", + "We now set up a `Dynamics` layer. This is a custom Keras layer which represents the mapping from an initial condition, in terms of the initial vorticity fields, to dynamical trajectories, in terms of the vorticity fields evaluated at later times. The trajectories are computed by solving the barotropic vorticity equation, while being forced with an extra right-hand-side term defined by the given `keras.models.Model`.\n", "\n", "In the following we use the AdamW optimizer, but here we have only one batch of size one. In fact here we are solving a standard variational optimization problem, but without an explicit regularization term, and so for this problem it might be better to use a deterministic optimizer. We do, however, increase the learning rate.\n", "\n", @@ -130,8 +130,7 @@ "output_weight = (model.grid.N_x + 1) * (model.grid.N_y + 1) * jnp.sqrt(model.grid.W / (4 * model.grid.L_x * model.grid.L_y)) / (model.beta * model.grid.L_y)\n", "dynamics_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1))\n", "dynamics_layer = Dynamics(\n", - " model, Q_0=jnp.zeros_like(model.fields[\"Q\"]), Q_network=Q_network, Q_callback=Q_callback,\n", - " N=1, n_output=N, output_weight=output_weight)\n", + " model, update, Q_network, N=1, n_output=N, output_weight=output_weight)\n", "dynamics_network = keras.models.Model(\n", " inputs=dynamics_input_layer, outputs=dynamics_layer(dynamics_input_layer))\n", "dynamics_network.compile(optimizer=keras.optimizers.AdamW(learning_rate=0.1),\n", diff --git a/tests/test_io.py b/tests/test_io.py index fd2b98d..e70cfd4 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,6 +1,6 @@ from bt_ocean.grid import Grid from bt_ocean.model import CNAB2Solver, Fields, Parameters, Solver -from bt_ocean.parameters import parameters +from bt_ocean.parameters import parameters, Q import jax.numpy as jnp import zarr @@ -63,42 +63,42 @@ def test_fields_roundtrip(tmp_path): def test_solver_roundtrip(tmp_path): - solver = CNAB2Solver(model_parameters()) - + model = CNAB2Solver(model_parameters()) + model.fields["Q"] = Q(model.grid) for _ in range(5): - solver.step() + model.step() filename = tmp_path / "tmp.zarr" with zarr.open(filename, "w") as h: - solver.write(h) + model.write(h) with zarr.open(filename, "r") as h: - input_solver = Solver.read(h) - - assert type(input_solver) is type(solver) - assert input_solver.n == solver.n - - assert input_solver.grid.L_x == solver.grid.L_x - assert input_solver.grid.L_y == solver.grid.L_y - assert input_solver.grid.N_x == solver.grid.N_x - assert input_solver.grid.N_y == solver.grid.N_y - assert input_solver.grid.idtype == solver.grid.idtype - assert input_solver.grid.fdtype == solver.grid.fdtype - - assert input_solver.dealias_grid.L_x == solver.dealias_grid.L_x - assert input_solver.dealias_grid.L_y == solver.dealias_grid.L_y - assert input_solver.dealias_grid.N_x == solver.dealias_grid.N_x - assert input_solver.dealias_grid.N_y == solver.dealias_grid.N_y - assert input_solver.dealias_grid.idtype == solver.dealias_grid.idtype - assert input_solver.dealias_grid.fdtype == solver.dealias_grid.fdtype - - assert set(input_solver.parameters) == set(solver.parameters) - for key, value in solver.parameters.items(): - assert input_solver.parameters[key] == value - - assert set(input_solver.fields) == set(solver.fields) - for key, value in solver.fields.items(): - assert (input_solver.fields[key] == value).all() - - assert set(input_solver.dealias_fields) == set(solver.dealias_fields) - for key, value in solver.dealias_fields.items(): - assert (input_solver.dealias_fields[key] == value).all() + input_model = Solver.read(h) + + assert type(input_model) is type(model) + assert input_model.n == model.n + + assert input_model.grid.L_x == model.grid.L_x + assert input_model.grid.L_y == model.grid.L_y + assert input_model.grid.N_x == model.grid.N_x + assert input_model.grid.N_y == model.grid.N_y + assert input_model.grid.idtype == model.grid.idtype + assert input_model.grid.fdtype == model.grid.fdtype + + assert input_model.dealias_grid.L_x == model.dealias_grid.L_x + assert input_model.dealias_grid.L_y == model.dealias_grid.L_y + assert input_model.dealias_grid.N_x == model.dealias_grid.N_x + assert input_model.dealias_grid.N_y == model.dealias_grid.N_y + assert input_model.dealias_grid.idtype == model.dealias_grid.idtype + assert input_model.dealias_grid.fdtype == model.dealias_grid.fdtype + + assert set(input_model.parameters) == set(model.parameters) + for key, value in model.parameters.items(): + assert input_model.parameters[key] == value + + assert set(input_model.fields) == set(model.fields) + for key, value in model.fields.items(): + assert (input_model.fields[key] == value).all() + + assert set(input_model.dealias_fields) == set(model.dealias_fields) + for key, value in model.dealias_fields.items(): + assert (input_model.dealias_fields[key] == value).all() diff --git a/tests/test_network.py b/tests/test_network.py index f7fc560..bbe4872 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -5,11 +5,23 @@ import keras from numpy import sqrt -from bt_ocean.network import KroneckerProduct, Scale +from bt_ocean.model import CNAB2Solver, Parameters +from bt_ocean.network import Dynamics, KroneckerProduct, Scale +from bt_ocean.parameters import parameters, Q from .test_base import test_precision # noqa: F401 +def model_parameters(): + n_hour = 1 + model_parameters = dict(parameters) + model_parameters["dt"] = 3600 / n_hour + model_parameters["N_x"] = 32 + model_parameters["N_y"] = 32 + model_parameters["nu"] = 1.0e5 + return Parameters(model_parameters) + + @pytest.mark.parametrize("alpha", [-sqrt(2), sqrt(3)]) def test_scale_roundtrip(tmp_path, alpha): input_layer = keras.layers.Input((3, 2)) @@ -72,3 +84,65 @@ def test_kronecker_product_roundtrip(tmp_path, activation, symmetric, bias): assert w_i.shape == w_j.shape assert w_i.dtype == w_j.dtype assert (w_i == w_j).all() + + +def test_dynamics_roundtrip(tmp_path): + model = CNAB2Solver(model_parameters()) + model.fields["Q"] = Q(model.grid) + for _ in range(5): + model.step() + + Q_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1)) + Q_network = keras.models.Model(inputs=Q_input_layer, outputs=Q_input_layer) + + n_calls = 0 + + @Dynamics.register_update("test_dynamics_roundtrip_Q_callback") + def Q_callback(dynamics, Q_network): + nonlocal n_calls + n_calls += 1 + + dynamics_layer = Dynamics(model, Q_callback, Q_network, N=1) + dynamics_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1)) + dynamics_network = keras.models.Model(inputs=dynamics_input_layer, outputs=dynamics_layer(dynamics_input_layer)) + + assert n_calls == 0 + dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_x + 1))) + assert n_calls == 1 + + dynamics_network.save(tmp_path / "tmp.keras") + dynamics_network = keras.models.load_model(tmp_path / "tmp.keras") + + input_model = dynamics_network.layers[1]._Dynamics__dynamics + + assert type(input_model) is type(model) + assert input_model.n == model.n + + assert input_model.grid.L_x == model.grid.L_x + assert input_model.grid.L_y == model.grid.L_y + assert input_model.grid.N_x == model.grid.N_x + assert input_model.grid.N_y == model.grid.N_y + assert input_model.grid.idtype == model.grid.idtype + assert input_model.grid.fdtype == model.grid.fdtype + + assert input_model.dealias_grid.L_x == model.dealias_grid.L_x + assert input_model.dealias_grid.L_y == model.dealias_grid.L_y + assert input_model.dealias_grid.N_x == model.dealias_grid.N_x + assert input_model.dealias_grid.N_y == model.dealias_grid.N_y + assert input_model.dealias_grid.idtype == model.dealias_grid.idtype + assert input_model.dealias_grid.fdtype == model.dealias_grid.fdtype + + assert set(input_model.parameters) == set(model.parameters) + for key, value in model.parameters.items(): + assert input_model.parameters[key] == value + + assert set(input_model.fields) == set(model.fields) + for key, value in model.fields.items(): + assert (input_model.fields[key] == value).all() + + assert set(input_model.dealias_fields) == set(model.dealias_fields) + for key, value in model.dealias_fields.items(): + assert (input_model.dealias_fields[key] == value).all() + + dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_x + 1))) + assert n_calls == 2 From d0c942d40fcdf960e0ee060aba5fdb22cd5df199 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 29 Oct 2024 11:47:06 +0000 Subject: [PATCH 8/8] Tidying --- bt_ocean/model.py | 2 +- bt_ocean/network.py | 2 +- tests/test_network.py | 6 +++--- 3 files changed, 5 insertions(+), 5 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 1d5678b..3df4c14 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -955,7 +955,7 @@ def ke(self): v = self.dealias_fields["v"] return 0.5 * jnp.tensordot((u * u + v * v), self.dealias_grid.W) - def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000, _min_n=0): + def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000): return super().steady_state_solve(*args, update=update, tol=tol, max_it=max_it, _min_n=1) def new(self): diff --git a/bt_ocean/network.py b/bt_ocean/network.py index aa3bb68..859982c 100644 --- a/bt_ocean/network.py +++ b/bt_ocean/network.py @@ -228,7 +228,7 @@ def register_update(cls, key): ---------- key : str - Key to associated with the callable. + Key to associated with the `update` callable. Returns ------- diff --git a/tests/test_network.py b/tests/test_network.py index bbe4872..263d61d 100644 --- a/tests/test_network.py +++ b/tests/test_network.py @@ -102,12 +102,12 @@ def Q_callback(dynamics, Q_network): nonlocal n_calls n_calls += 1 - dynamics_layer = Dynamics(model, Q_callback, Q_network, N=1) + dynamics_layer = Dynamics(model, Q_callback, Q_network) dynamics_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1)) dynamics_network = keras.models.Model(inputs=dynamics_input_layer, outputs=dynamics_layer(dynamics_input_layer)) assert n_calls == 0 - dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_x + 1))) + dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_y + 1))) assert n_calls == 1 dynamics_network.save(tmp_path / "tmp.keras") @@ -144,5 +144,5 @@ def Q_callback(dynamics, Q_network): for key, value in model.dealias_fields.items(): assert (input_model.dealias_fields[key] == value).all() - dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_x + 1))) + dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_y + 1))) assert n_calls == 2