diff --git a/bt_ocean/model.py b/bt_ocean/model.py
index f965ed3..3df4c14 100644
--- a/bt_ocean/model.py
+++ b/bt_ocean/model.py
@@ -4,6 +4,7 @@
 import jax
 import jax.numpy as jnp
+import keras
 import numpy as np

 from abc import ABC, abstractmethod
@@ -18,16 +19,13 @@
     [
         "Parameters",
         "required",
-        "read_parameters",

         "Fields",
-        "read_fields",

         "SteadyStateMaximumIterationsError",
         "NanEncounteredError",
         "Solver",
-        "CNAB2Solver",
-        "read_solver"
+        "CNAB2Solver"
     ]
@@ -93,26 +91,26 @@ def write(self, h, path="parameters"):
         g.attrs.update(self.items())
         return g

+    @classmethod
+    def read(cls, h, path="parameters"):
+        """Read parameters from a :class:`zarr.hierarchy.Group`.

-def read_parameters(h, path="parameters"):
-    """Read parameters from a :class:`zarr.hierarchy.Group`.
-
-    Parameters
-    ----------
+        Parameters
+        ----------

-    h : :class:`zarr.hierarchy.Group`
-        Parent group.
-    path : str
-        Group path.
+        h : :class:`zarr.hierarchy.Group`
+            Parent group.
+        path : str
+            Group path.

-    Returns
-    -------
+        Returns
+        -------

-    :class:`.Parameters`
-        The parameters.
-    """
+        :class:`.Parameters`
+            The parameters.
+        """

-    return Parameters(h[path].attrs)
+        return cls(h[path].attrs)


 class Fields(Mapping):
@@ -254,50 +252,50 @@ def update(self, d):
         for key, value in d.items():
             self[key] = value

+    @classmethod
+    def read(cls, h, path="fields", *, grid=None):
+        """Read fields from a :class:`zarr.hierarchy.Group`.

-def read_fields(h, path="fields", *, grid=None):
-    """Read fields from a :class:`zarr.hierarchy.Group`.
-
-    Parameters
-    ----------
+        Parameters
+        ----------

-    h : :class:`zarr.hierarchy.Group`
-        Parent group.
-    path : str
-        Group path.
-    grid : :class:`.Grid`
-        The 2D Chebyshev grid.
+        h : :class:`zarr.hierarchy.Group`
+            Parent group.
+        path : str
+            Group path.
+        grid : :class:`.Grid`
+            The 2D Chebyshev grid.

-    Returns
-    -------
+        Returns
+        -------

-    :class:`.Fields`
-        The fields.
-    """
+        :class:`.Fields`
+            The fields.
+        """

-    g = h[path]
-    del h
+        g = h[path]
+        del h

-    L_x = g.attrs["L_x"]
-    L_y = g.attrs["L_y"]
-    N_x = g.attrs["N_x"]
-    N_y = g.attrs["N_y"]
-    idtype = jnp.dtype(g.attrs["idtype"]).type
-    fdtype = jnp.dtype(g.attrs["fdtype"]).type
-    if grid is None:
-        grid = Grid(L_x, L_y, N_x, N_y, idtype=idtype, fdtype=fdtype)
-    if L_x != grid.L_x or L_y != grid.L_y:
-        raise ValueError("Invalid dimension(s)")
-    if N_x != grid.N_x or N_y != grid.N_y:
-        raise ValueError("Invalid degree(s)")
-    if idtype != grid.idtype or fdtype != grid.fdtype:
-        raise ValueError("Invalid dtype(s)")
+        L_x = g.attrs["L_x"]
+        L_y = g.attrs["L_y"]
+        N_x = g.attrs["N_x"]
+        N_y = g.attrs["N_y"]
+        idtype = jnp.dtype(g.attrs["idtype"]).type
+        fdtype = jnp.dtype(g.attrs["fdtype"]).type
+        if grid is None:
+            grid = Grid(L_x, L_y, N_x, N_y, idtype=idtype, fdtype=fdtype)
+        if L_x != grid.L_x or L_y != grid.L_y:
+            raise ValueError("Invalid dimension(s)")
+        if N_x != grid.N_x or N_y != grid.N_y:
+            raise ValueError("Invalid degree(s)")
+        if idtype != grid.idtype or fdtype != grid.fdtype:
+            raise ValueError("Invalid dtype(s)")

-    fields = Fields(grid, set(g))
-    for key in g:
-        fields[key] = g[key][...]
+        fields = cls(grid, set(g))
+        for key in g:
+            fields[key] = g[key][...]

-    return fields
+        return fields


 class SteadyStateMaximumIterationsError(Exception):
@@ -415,15 +413,16 @@ def __init__(self, parameters, *, field_keys=None, dealias_field_keys=None,

     def __init_subclass__(cls, **kwargs):
         super().__init_subclass__(**kwargs)

-        def flatten(solver):
-            return solver.flatten()
+        def flatten(model):
+            return model.flatten()

         def unflatten(aux_data, children):
             return cls.unflatten(aux_data, children)

         jax.tree_util.register_pytree_node(cls, flatten, unflatten)
-        Solver._registry[cls.__name__] = cls
+        cls._registry[cls.__name__] = cls
+        keras.saving.register_keras_serializable(package=f"_bt_ocean__{cls.__name__}")(cls)

     @cached_property
     def beta(self) -> jax.Array:
@@ -593,7 +592,7 @@ def ke_spectrum(self, N_x, N_y):
         return (0.5 * self.grid.L_x * self.grid.L_y * (k ** 2 + l ** 2)
                 * (dst(dst(psi, axis=1), axis=0) ** 2))

-    def steady_state_solve(self, m=(), update=lambda model, *m: None, *, tol, max_it=10000, _min_n=0):
+    def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000, _min_n=0):
         r"""Timestep to steady-state.

         Uses timestepping to define a fixed-point iteration, and applies
@@ -615,12 +614,12 @@
         Parameters
         ----------

-        m : Sequence[:class:`jax.Array`, ...]
-            Additional control variables.
         update : callable
             A callable accepting a :class:`.Solver` as the zeroth argument
-            and the elements of `m` as remaining positional arguments, and
+            and the elements of `args` as remaining positional arguments, and
             which updates the values of control variables.
+        args : tuple
+            Passed to `update`.
         tol : Real
             Tolerance. The system is timestepped until
@@ -642,18 +641,18 @@
         """

         @jax.jit
-        def forward_step(data, m):
+        def forward_step(data, args):
             _, model, it = data
             zeta_0 = model.fields["zeta"]
-            update(model, *m)
+            update(model, *args)
             model.step()
             return (zeta_0, model, it + 1)

         @jax.custom_vjp
-        def forward(model, m):
+        def forward(model, args):
             while model.n < _min_n:
-                zeta_0, model, _ = forward_step((None, model, 0), m)
-            zeta_0, model, _ = forward_step((None, model, 0), m)
+                zeta_0, model, _ = forward_step((None, model, 0), args)
+            zeta_0, model, _ = forward_step((None, model, 0), args)

             def non_convergence(data):
                 zeta_0, model, it = data
@@ -663,20 +662,20 @@ def non_convergence(data):
                     abs(zeta_1 - zeta_0).max() > tol * abs(zeta_1).max())

             _, model, it = jax.lax.while_loop(
-                non_convergence, partial(forward_step, m=m),
+                non_convergence, partial(forward_step, args=args),
                 (zeta_0, model, 1))
             return model, it

-        def forward_fwd(model, m):
-            model, it = forward(model, m)
-            return (model, it), (model, m)
+        def forward_fwd(model, args):
+            model, it = forward(model, args)
+            return (model, it), (model, args)

         def forward_bwd(res, zeta):
-            model, m = res
+            model, args = res
             zeta_model, _ = zeta
             _, vjp_step = jax.vjp(
-                lambda model: forward_step((None, model, 0), m)[1], model)
+                lambda model: forward_step((None, model, 0), args)[1], model)

             @jax.jit
             def adj_step(data, zeta_model):
@@ -709,14 +708,14 @@ def non_convergence(data):
                 return lam_model

             lam_model = adjoint(zeta_model)
-            _, vjp = jax.vjp(lambda m: forward_step((None, model, 0), m)[1], m)
-            lam_m, = vjp(lam_model)
+            _, vjp = jax.vjp(lambda args: forward_step((None, model, 0), args)[1], args)
+            lam_args, = vjp(lam_model)

-            return lam_model, lam_m
+            return lam_model, lam_args

         forward.defvjp(forward_fwd, forward_bwd)

-        model, it = forward(self, m)
+        model, it = forward(self, args)
         self.update(model)
         if it > max_it:
             raise SteadyStateMaximumIterationsError("Maximum number of iterations exceeded")
@@ -752,6 +751,36 @@ def write(self, h, path="solver"):
         return g

+    @classmethod
+    def read(cls, h, path="solver"):
+        """Read solver from a :class:`zarr.hierarchy.Group`.
+
+        Parameters
+        ----------
+
+        h : :class:`zarr.hierarchy.Group`
+            Parent group.
+        path : str
+            Group path.
+
+        Returns
+        -------
+
+        :class:`.Solver`
+            The solver.
+        """
+
+        g = h[path]
+        del h
+
+        cls = cls._registry[g.attrs["type"]]
+        model = cls(Parameters.read(g, "parameters"))
+        model.fields.update(Fields.read(g, "fields", grid=model.grid))
+        model.dealias_fields.update(Fields.read(g, "dealias_fields", grid=model.dealias_grid))
+        model.n = g.attrs["n"]
+
+        return model
+
     def new(self):
         """Return a new :class:`.Solver` with the same configuration as this
         :class:`.Solver`.
@@ -763,23 +792,23 @@ def new(self):
             The new :class:`.Solver`.
         """

-        solver = type(self)(self.parameters)
-        solver.poisson_solver = self.poisson_solver
-        return solver
+        model = type(self)(self.parameters)
+        model.poisson_solver = self.poisson_solver
+        return model

-    def update(self, solver):
+    def update(self, model):
         """Update the state of this :class:`.Solver`.

         Parameters
         ----------

-        solver : :class:`.Solver`
+        model : :class:`.Solver`
             Defines the new state of this :class:`.Solver`.
         """

-        self.fields.update(solver.fields)
-        self.dealias_fields.update(solver.dealias_fields)
-        self.n = solver.n
+        self.fields.update(model.fields)
+        self.dealias_fields.update(model.dealias_fields)
+        self.n = model.n

     def flatten(self):
         """Return a JAX flattened representation.
@@ -791,54 +820,41 @@ def flatten(self):
         """

         return ((dict(self.fields), dict(self.dealias_fields), self.n),
-                (type(self), self.parameters, self.poisson_solver))
+                (self.parameters, self.poisson_solver))

-    @staticmethod
-    def unflatten(aux_data, children):
+    @classmethod
+    def unflatten(cls, aux_data, children):
         """Unpack a JAX flattened representation.
         """

-        cls, parameters, poisson_solver = aux_data
+        parameters, poisson_solver = aux_data
         fields, dealias_fields, n = children

-        solver = cls(parameters)
-        solver.poisson_solver = poisson_solver
-        solver.fields.update({key: value for key, value in fields.items() if type(value) is not object})
-        solver.dealias_fields.update({key: value for key, value in dealias_fields.items() if type(value) is not object})
+        model = cls(parameters)
+        model.poisson_solver = poisson_solver
+        model.fields.update({key: value for key, value in fields.items() if type(value) is not object})
+        model.dealias_fields.update({key: value for key, value in dealias_fields.items() if type(value) is not object})
         if type(n) is not object:
-            solver.n = n
-
-        return solver
-
-
-def read_solver(h, path="solver"):
-    """Read solver from a :class:`zarr.hierarchy.Group`.
-
-    Parameters
-    ----------
-
-    h : :class:`zarr.hierarchy.Group`
-        Parent group.
-    path : str
-        Group path.
-
-    Returns
-    -------
-
-    :class:`.Solver`
-        The solver.
-    """
+            model.n = n

-    g = h[path]
-    del h
+        return model

-    cls = Solver._registry[g.attrs["type"]]
-    solver = cls(read_parameters(g, "parameters"))
-    solver.fields.update(read_fields(g, "fields", grid=solver.grid))
-    solver.dealias_fields.update(read_fields(g, "dealias_fields", grid=solver.dealias_grid))
-    solver.n = g.attrs["n"]
+    def get_config(self):
+        return {"type": type(self).__name__,
+                "parameters": dict(self.parameters),
+                "fields": dict(self.fields),
+                "dealias_fields": dict(self.dealias_fields),
+                "n": self.n}

-    return solver
+    @classmethod
+    def from_config(cls, config):
+        config = {key: keras.saving.deserialize_keras_object(value) for key, value in config.items()}
+        cls = cls._registry[config["type"]]
+        model = cls(config["parameters"])
+        model.fields.update(config["fields"])
+        model.dealias_fields.update(config["dealias_fields"])
+        model.n = config["n"]
+        return model


 class CNAB2Solver(Solver):
@@ -939,20 +955,20 @@ def ke(self):
         v = self.dealias_fields["v"]
         return 0.5 * jnp.tensordot((u * u + v * v), self.dealias_grid.W)

-    def steady_state_solve(self, m=(), update=lambda model, *m: None, *, tol, max_it=10000):
-        return super().steady_state_solve(m=m, update=update, tol=tol, _min_n=1, max_it=max_it)
+    def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000):
+        return super().steady_state_solve(*args, update=update, tol=tol, max_it=max_it, _min_n=1)

     def new(self):
-        solver = super().new()
-        solver.modified_helmholtz_solver = self.modified_helmholtz_solver
-        return solver
+        model = super().new()
+        model.modified_helmholtz_solver = self.modified_helmholtz_solver
+        return model

     def flatten(self):
         children, aux_data = super().flatten()
         return children, aux_data + (self.modified_helmholtz_solver,)

-    @staticmethod
-    def unflatten(aux_data, children):
-        solver = Solver.unflatten(aux_data[:-1], children)
-        solver.modified_helmholtz_solver = aux_data[-1]
-        return solver
+    @classmethod
+    def unflatten(cls, aux_data, children):
+        model = super().unflatten(aux_data[:-1], children)
+        model.modified_helmholtz_solver = aux_data[-1]
+        return model
diff --git a/bt_ocean/network.py b/bt_ocean/network.py
index bd2415b..859982c 100644
--- a/bt_ocean/network.py
+++ b/bt_ocean/network.py
@@ -19,7 +19,16 @@
     ]


-@keras.saving.register_keras_serializable(package="bt_ocean_scale")
+def update_config(config, d):
+    config = dict(config)
+    for key, value in d.items():
+        if key in config:
+            raise KeyError(f"key '{key}' already defined")
+        config[key] = value
+    return config
+
+
+@keras.saving.register_keras_serializable(package="_bt_ocean__Scale")
 class Scale(keras.layers.Layer):
     """A layer which multiplies by a constant trainable weight.
     """
@@ -35,7 +44,7 @@ def build(self, input_shape):
         pass


-@keras.saving.register_keras_serializable(package="bt_ocean_kronecker_product")
+@keras.saving.register_keras_serializable(package="_bt_ocean__KroneckerProduct")
 class KroneckerProduct(keras.layers.Layer):
     """A layer where the weights matrix has Kronecker product structure.
@@ -86,22 +95,6 @@ def call(self, inputs):
         outputs = self.__activation(outputs)
         return outputs

-    def get_config(self):
-        def update(config, d):
-            config = dict(config)
-            for key, value in d.items():
-                if key in config:
-                    raise KeyError(f"key '{key}' already defined")
-                config[key] = value
-            return config
-
-        return update(super().get_config(),
-                      {"shape_a": self.__shape_a,
-                       "shape_b": self.__shape_b,
-                       "activation": self.__activation_arg,
-                       "symmetric": self.__symmetric,
-                       "bias": self.__bias})
-
     def build(self, input_shape):
         if tuple(input_shape)[-2:] != self.__shape_a:
             raise ValueError("Invalid shape")
@@ -186,23 +179,20 @@ def kronecker_product_network(
     return keras.models.Model(inputs=input_layer, outputs=output_layer)


+@keras.saving.register_keras_serializable(package="_bt_ocean__Dynamics")
 class Dynamics(keras.layers.Layer):
-    """Defines a layer consisting of a dynamical core with a neural network
-    parameterized forcing.
+    """Defines a layer consisting of a dynamical solver.

     Parameters
     ----------

     dynamics : :class:`.Solver`
-        The dynamical core.
-    Q_0 : :class:`jax.numpy.Array`
-        Wind forcing term in the vorticity equation.
-    Q_network
-        The right-hand-side forcing neural network.
-    Q_callback : callable
-        Passed `dynamics` and `Q_network`, and should return an additional term
-        to be added to the right-hand-side of the vorticity equation. Evaluated
-        before taking each timestep.
+        The dynamical solver.
+    update : callable
+        Passed `dynamics` and any arguments defined by `args`, and should
+        update the state of `dynamics`. Evaluated before taking each timestep.
+    args : tuple
+        Passed as remaining arguments to `update`.
     N : Integral
         The number of timesteps to take using the dynamical solver
         between each output.
@@ -214,25 +204,51 @@ class Dynamics(keras.layers.Layer):
         Weight by which to scale each output.
     """

-    def __init__(self, dynamics, Q_0, Q_network, Q_callback, N, *, n_output=1,
-                 input_weight=1, output_weight=1):
-        super().__init__(dtype=dynamics.grid.fdtype)
+    _update_registry = {}
+
+    def __init__(self, dynamics, update, *args, N=1, n_output=1,
+                 input_weight=1, output_weight=1, **kwargs):
+        if "dtype" not in kwargs:
+            kwargs["dtype"] = dynamics.grid.fdtype
+        super().__init__(**kwargs)
         self.__dynamics = dynamics
-        self.__Q_0 = Q_0
-        self.__Q_network = Q_network
-        self.__Q_callback = Q_callback
+        self.__update = update
+        self.__args = args
         self.__N = N
         self.__n_output = n_output
         self.__input_weight = input_weight
         self.__output_weight = output_weight

+    @classmethod
+    def register_update(cls, key):
+        """Decorator for registration of an `update` callable. Required for
+        :class:`.Dynamics` serialization.
+
+        Parameters
+        ----------
+
+        key : str
+            Key to associate with the `update` callable.
+
+        Returns
+        -------
+
+        callable
+        """
+
+        def wrapper(fn):
+            fn._bt_ocean__update_key = key
+            cls._update_registry[key] = fn
+            return fn
+        return wrapper
+
     def compute_output_shape(self, input_shape):
         return (input_shape[0], self.__n_output) + input_shape[1:]

     def call(self, inputs):
         @jax.checkpoint
         def step(_, dynamics):
-            dynamics.fields["Q"] = self.__Q_0 + self.__Q_callback(dynamics, self.__Q_network)
+            self.__update(dynamics, *self.__args)
             dynamics.step()
             return dynamics
@@ -249,6 +265,28 @@ def compute_outputs(zeta):
         outputs = jax.vmap(compute_outputs)(inputs)
         return outputs

+    def get_config(self):
+        return update_config(
+            super().get_config(),
+            {"_bt_ocean__Dynamics_config": {"dynamics": self.__dynamics,
+                                            "update_key": self.__update._bt_ocean__update_key,
+                                            "args": self.__args,
+                                            "N": self.__N,
+                                            "n_output": self.__n_output,
+                                            "input_weight": self.__input_weight,
+                                            "output_weight": self.__output_weight}})
+
+    @classmethod
+    def from_config(cls, config):
+        sub_config = {key: keras.saving.deserialize_keras_object(value) for key, value in config.pop("_bt_ocean__Dynamics_config").items()}
+        return cls(sub_config["dynamics"],
+                   cls._update_registry[sub_config["update_key"]],
+                   *sub_config["args"],
+                   N=sub_config["N"],
+                   n_output=sub_config["n_output"],
+                   input_weight=sub_config["input_weight"],
+                   output_weight=sub_config["output_weight"], **config)
+

 class OnlineDataset(keras.utils.PyDataset):
     """Online training data set.
diff --git a/docs/source/examples/1_keras_integration.ipynb b/docs/source/examples/1_keras_integration.ipynb
index d56347f..7cede91 100644
--- a/docs/source/examples/1_keras_integration.ipynb
+++ b/docs/source/examples/1_keras_integration.ipynb
@@ -72,7 +72,7 @@
    "source": [
     "## A simple Keras model\n",
     "\n",
-    "We now set up our Keras model. To do this we first define a `keras.models.Model` and callable, which together define a map from our state to a forcing term appearing on the right-hand-side of the barotropic vorticity equation."
+    "We now set up our Keras model. To do this we first define a `keras.models.Model`, which maps our state to a forcing term added on the right-hand-side of the barotropic vorticity equation, and a callable which applies this map."
    ]
   },
   {
@@ -102,8 +102,8 @@
     "Q_weight = tau_0 * jnp.pi / (D * rho_0 * model.grid.L_y)\n",
     "\n",
     "\n",
-    "def Q_callback(dynamics, Q_network):\n",
-    "    return Q_weight * Q_network(jnp.zeros(shape=(1, 0)))[0, :, :]"
+    "def update(dynamics, Q_network):\n",
+    "    dynamics.fields[\"Q\"] = Q_weight * Q_network(jnp.zeros(shape=(1, 0)))[0, :, :]"
    ]
   },
   {
@@ -111,7 +111,7 @@
    "cell_type": "markdown",
    "id": "d23edf22-f916-41fc-8095-2bdd8d9be6c9",
    "metadata": {},
    "source": [
-    "We now set up a `Dynamics` layer. This is a custom Keras layer which represents the mapping from an initial condition, in terms of the initial vorticity fields, to dynamical trajectories, in terms of the vorticity fields evaluated at later times. The trajectories are computed by solving the barotropic vorticity equation, while being forced with an extra right-hand-side term defined by the given `keras.models.Model`. We indicate that the `Dynamics` layer should use a known right-hand-side forcing term (the wind stress term, defined by the `Q_0` argument) of zero – meaning that in this example we will seek `Q_network` weights which reconstruct this term.\n",
+    "We now set up a `Dynamics` layer. This is a custom Keras layer which represents the mapping from an initial condition, in terms of the initial vorticity fields, to dynamical trajectories, in terms of the vorticity fields evaluated at later times. The trajectories are computed by solving the barotropic vorticity equation, while being forced with an extra right-hand-side term defined by the given `keras.models.Model`.\n",
     "\n",
     "In the following we use the AdamW optimizer, but here we have only one batch of size one. In fact here we are solving a standard variational optimization problem, but without an explicit regularization term, and so for this problem it might be better to use a deterministic optimizer. We do, however, increase the learning rate.\n",
     "\n",
@@ -130,8 +130,7 @@
     "output_weight = (model.grid.N_x + 1) * (model.grid.N_y + 1) * jnp.sqrt(model.grid.W / (4 * model.grid.L_x * model.grid.L_y)) / (model.beta * model.grid.L_y)\n",
     "dynamics_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1))\n",
     "dynamics_layer = Dynamics(\n",
-    "    model, Q_0=jnp.zeros_like(model.fields[\"Q\"]), Q_network=Q_network, Q_callback=Q_callback,\n",
-    "    N=1, n_output=N, output_weight=output_weight)\n",
+    "    model, update, Q_network, N=1, n_output=N, output_weight=output_weight)\n",
     "dynamics_network = keras.models.Model(\n",
     "    inputs=dynamics_input_layer, outputs=dynamics_layer(dynamics_input_layer))\n",
     "dynamics_network.compile(optimizer=keras.optimizers.AdamW(learning_rate=0.1),\n",
diff --git a/docs/source/examples/2_steady_state.ipynb b/docs/source/examples/2_steady_state.ipynb
index ebb5790..e3f584e 100644
--- a/docs/source/examples/2_steady_state.ipynb
+++ b/docs/source/examples/2_steady_state.ipynb
@@ -180,7 +180,7 @@
     "        model.fields[\"Q\"] = Q - (r - model.r) * model.fields[\"zeta\"]\n",
     "\n",
     "    model = CNAB2Solver(parameters)\n",
-    "    model.steady_state_solve(m=(r,), update=update, tol=1.0e-6)\n",
+    "    model.steady_state_solve(r, update=update, tol=1.0e-6)\n",
     "    return D * model.fields[\"psi\"].max() / 1.0e6\n",
     "\n",
     "\n",
diff --git a/tests/test_io.py b/tests/test_io.py
index 0f663e8..e70cfd4 100644
--- a/tests/test_io.py
+++ b/tests/test_io.py
@@ -1,7 +1,6 @@
 from bt_ocean.grid import Grid
-from bt_ocean.model import (
-    CNAB2Solver, Fields, Parameters, read_fields, read_parameters, read_solver)
-from bt_ocean.parameters import parameters
+from bt_ocean.model import CNAB2Solver, Fields, Parameters, Solver
+from bt_ocean.parameters import parameters, Q

 import jax.numpy as jnp
 import zarr
@@ -25,7 +24,7 @@ def test_parameters_roundtrip(tmp_path):
     with zarr.open(filename, "w") as h:
         parameters.write(h)
     with zarr.open(filename, "r") as h:
-        input_parameters = read_parameters(h)
+        input_parameters = Parameters.read(h)

     assert set(input_parameters) == set(parameters)
     for key, value in parameters.items():
@@ -47,7 +46,7 @@ def test_fields_roundtrip(tmp_path):
     with zarr.open(filename, "w") as h:
         fields.write(h)
     with zarr.open(filename, "r") as h:
-        input_fields = read_fields(h)
+        input_fields = Fields.read(h)

     assert input_fields.grid.L_x == L_x
     assert input_fields.grid.L_y == L_y
@@ -64,42 +63,42 @@ def test_fields_roundtrip(tmp_path):


 def test_solver_roundtrip(tmp_path):
-    solver = CNAB2Solver(model_parameters())
-
+    model = CNAB2Solver(model_parameters())
+    model.fields["Q"] = Q(model.grid)
     for _ in range(5):
-        solver.step()
+        model.step()

     filename = tmp_path / "tmp.zarr"
     with zarr.open(filename, "w") as h:
-        solver.write(h)
+        model.write(h)
     with zarr.open(filename, "r") as h:
-        input_solver = read_solver(h)
-
-    assert type(input_solver) is type(solver)
-    assert input_solver.n == solver.n
-
-    assert input_solver.grid.L_x == solver.grid.L_x
-    assert input_solver.grid.L_y == solver.grid.L_y
-    assert input_solver.grid.N_x == solver.grid.N_x
-    assert input_solver.grid.N_y == solver.grid.N_y
-    assert input_solver.grid.idtype == solver.grid.idtype
-    assert input_solver.grid.fdtype == solver.grid.fdtype
-
-    assert input_solver.dealias_grid.L_x == solver.dealias_grid.L_x
-    assert input_solver.dealias_grid.L_y == solver.dealias_grid.L_y
-    assert input_solver.dealias_grid.N_x == solver.dealias_grid.N_x
-    assert input_solver.dealias_grid.N_y == solver.dealias_grid.N_y
-    assert input_solver.dealias_grid.idtype == solver.dealias_grid.idtype
-    assert input_solver.dealias_grid.fdtype == solver.dealias_grid.fdtype
-
-    assert set(input_solver.parameters) == set(solver.parameters)
-    for key, value in solver.parameters.items():
-        assert input_solver.parameters[key] == value
-
-    assert set(input_solver.fields) == set(solver.fields)
-    for key, value in solver.fields.items():
-        assert (input_solver.fields[key] == value).all()
-
-    assert set(input_solver.dealias_fields) == set(solver.dealias_fields)
-    for key, value in solver.dealias_fields.items():
-        assert (input_solver.dealias_fields[key] == value).all()
+        input_model = Solver.read(h)
+
+    assert type(input_model) is type(model)
+    assert input_model.n == model.n
+
+    assert input_model.grid.L_x == model.grid.L_x
+    assert input_model.grid.L_y == model.grid.L_y
+    assert input_model.grid.N_x == model.grid.N_x
+    assert input_model.grid.N_y == model.grid.N_y
+    assert input_model.grid.idtype == model.grid.idtype
+    assert input_model.grid.fdtype == model.grid.fdtype
+
+    assert input_model.dealias_grid.L_x == model.dealias_grid.L_x
+    assert input_model.dealias_grid.L_y == model.dealias_grid.L_y
+    assert input_model.dealias_grid.N_x == model.dealias_grid.N_x
+    assert input_model.dealias_grid.N_y == model.dealias_grid.N_y
+    assert input_model.dealias_grid.idtype == model.dealias_grid.idtype
+    assert input_model.dealias_grid.fdtype == model.dealias_grid.fdtype
+
+    assert set(input_model.parameters) == set(model.parameters)
+    for key, value in model.parameters.items():
+        assert input_model.parameters[key] == value
+
+    assert set(input_model.fields) == set(model.fields)
+    for key, value in model.fields.items():
+        assert (input_model.fields[key] == value).all()
+
+    assert set(input_model.dealias_fields) == set(model.dealias_fields)
+    for key, value in model.dealias_fields.items():
+        assert (input_model.dealias_fields[key] == value).all()
diff --git a/tests/test_network.py b/tests/test_network.py
index f7fc560..263d61d 100644
--- a/tests/test_network.py
+++ b/tests/test_network.py
@@ -5,11 +5,23 @@
 import keras
 from numpy import sqrt

-from bt_ocean.network import KroneckerProduct, Scale
+from bt_ocean.model import CNAB2Solver, Parameters
+from bt_ocean.network import Dynamics, KroneckerProduct, Scale
+from bt_ocean.parameters import parameters, Q

 from .test_base import test_precision  # noqa: F401


+def model_parameters():
+    n_hour = 1
+    model_parameters = dict(parameters)
+    model_parameters["dt"] = 3600 / n_hour
+    model_parameters["N_x"] = 32
+    model_parameters["N_y"] = 32
+    model_parameters["nu"] = 1.0e5
+    return Parameters(model_parameters)
+
+
 @pytest.mark.parametrize("alpha", [-sqrt(2), sqrt(3)])
 def test_scale_roundtrip(tmp_path, alpha):
     input_layer = keras.layers.Input((3, 2))
@@ -72,3 +84,65 @@ def test_kronecker_product_roundtrip(tmp_path, activation, symmetric, bias):
         assert w_i.shape == w_j.shape
         assert w_i.dtype == w_j.dtype
         assert (w_i == w_j).all()
+
+
+def test_dynamics_roundtrip(tmp_path):
+    model = CNAB2Solver(model_parameters())
+    model.fields["Q"] = Q(model.grid)
+    for _ in range(5):
+        model.step()
+
+    Q_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1))
+    Q_network = keras.models.Model(inputs=Q_input_layer, outputs=Q_input_layer)
+
+    n_calls = 0
+
+    @Dynamics.register_update("test_dynamics_roundtrip_Q_callback")
+    def Q_callback(dynamics, Q_network):
+        nonlocal n_calls
+        n_calls += 1
+
+    dynamics_layer = Dynamics(model, Q_callback, Q_network)
+    dynamics_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1))
+    dynamics_network = keras.models.Model(inputs=dynamics_input_layer, outputs=dynamics_layer(dynamics_input_layer))
+
+    assert n_calls == 0
+    dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_y + 1)))
+    assert n_calls == 1
+
+    dynamics_network.save(tmp_path / "tmp.keras")
+    dynamics_network = keras.models.load_model(tmp_path / "tmp.keras")
+
+    input_model = dynamics_network.layers[1]._Dynamics__dynamics
+
+    assert type(input_model) is type(model)
+    assert input_model.n == model.n
+
+    assert input_model.grid.L_x == model.grid.L_x
+    assert input_model.grid.L_y == model.grid.L_y
+    assert input_model.grid.N_x == model.grid.N_x
+    assert input_model.grid.N_y == model.grid.N_y
+    assert input_model.grid.idtype == model.grid.idtype
+    assert input_model.grid.fdtype == model.grid.fdtype
+
+    assert input_model.dealias_grid.L_x == model.dealias_grid.L_x
+    assert input_model.dealias_grid.L_y == model.dealias_grid.L_y
+    assert input_model.dealias_grid.N_x == model.dealias_grid.N_x
+    assert input_model.dealias_grid.N_y == model.dealias_grid.N_y
+    assert input_model.dealias_grid.idtype == model.dealias_grid.idtype
+    assert input_model.dealias_grid.fdtype == model.dealias_grid.fdtype
+
+    assert set(input_model.parameters) == set(model.parameters)
+    for key, value in model.parameters.items():
+        assert input_model.parameters[key] == value
+
+    assert set(input_model.fields) == set(model.fields)
+    for key, value in model.fields.items():
+        assert (input_model.fields[key] == value).all()
+
+    assert set(input_model.dealias_fields) == set(model.dealias_fields)
+    for key, value in model.dealias_fields.items():
+        assert (input_model.dealias_fields[key] == value).all()
+
+    dynamics_network(jnp.zeros((1, model.grid.N_x + 1, model.grid.N_y + 1)))
+    assert n_calls == 2
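
The classmethod-based zarr I/O introduced above can be exercised much as in test_solver_roundtrip. A minimal sketch, assuming a reduced-resolution configuration (the dt, N_x, N_y and nu values below mirror the test helpers and are not part of the diff) and a writable solver.zarr path:

import zarr

from bt_ocean.model import CNAB2Solver, Parameters, Solver
from bt_ocean.parameters import parameters, Q

# Reduced-resolution configuration so the sketch runs quickly (assumed values,
# mirroring the test helpers).
config = dict(parameters)
config["dt"] = 3600.0
config["N_x"] = 32
config["N_y"] = 32
config["nu"] = 1.0e5

model = CNAB2Solver(Parameters(config))
model.fields["Q"] = Q(model.grid)
for _ in range(5):
    model.step()

# Write the solver state, then read it back. Solver.read dispatches on the
# stored "type" attribute, so a CNAB2Solver instance is returned.
with zarr.open("solver.zarr", "w") as h:
    model.write(h)
with zarr.open("solver.zarr", "r") as h:
    input_model = Solver.read(h)

assert type(input_model) is CNAB2Solver
assert input_model.n == model.n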
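
Serialization of a Dynamics layer relies on register_update: the update callable itself is not stored in the Keras config, only its registered key, and the callable is looked up again on load. A sketch following test_dynamics_roundtrip, assuming a configured CNAB2Solver model (e.g. from the previous sketch); the key name, file name and the no-op update are illustrative only:

import keras

from bt_ocean.network import Dynamics


@Dynamics.register_update("example_Q_update")  # illustrative key, not in the diff
def q_update(dynamics, Q_network):
    # A real parameterization would set dynamics.fields["Q"] from the network
    # output; here the update is a no-op, registration is what matters.
    pass


# An identity network standing in for a forcing parameterization.
Q_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1))
Q_network = keras.models.Model(inputs=Q_input_layer, outputs=Q_input_layer)

dynamics_layer = Dynamics(model, q_update, Q_network, N=1, n_output=1)
dynamics_input_layer = keras.layers.Input((model.grid.N_x + 1, model.grid.N_y + 1))
dynamics_network = keras.models.Model(
    inputs=dynamics_input_layer, outputs=dynamics_layer(dynamics_input_layer))

# Saving stores the solver state, the update key and the network weights;
# loading rebuilds the layer via Dynamics.from_config.
dynamics_network.save("dynamics.keras")
dynamics_network = keras.models.load_model("dynamics.keras")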
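
With the revised steady_state_solve signature, control variables are passed positionally and update applies them to the model before each step, as in the 2_steady_state notebook. A sketch, again reusing the reduced-resolution config from the first sketch; convergence within the default max_it is not guaranteed for every configuration:

from bt_ocean.model import CNAB2Solver, Parameters
from bt_ocean.parameters import parameters, Q

model = CNAB2Solver(Parameters(config))
Q_field = Q(model.grid)


def update(model, Q_field):
    model.fields["Q"] = Q_field


# Positional arguments are forwarded to update(); derivatives with respect to
# them can be taken through the fixed-point solve via the custom VJP.
model.steady_state_solve(Q_field, update=update, tol=1.0e-6)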