From 39ef57fa30d594495a99927bad3b630cd9404c11 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 3 Dec 2024 12:37:19 +0000 Subject: [PATCH 1/2] Initialize prescribed fields in Dynamics.call --- bt_ocean/model.py | 29 +++++++++++++++++++++++------ bt_ocean/network.py | 2 +- 2 files changed, 24 insertions(+), 7 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 0365637..27ffb24 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -398,7 +398,7 @@ def __init__(self, parameters, *, idtype=None, fdtype=None, idtype=idtype, fdtype=fdtype) self._fields = Fields(grid, field_keys) - self._prescribed_field_keys = set(prescribed_field_keys) + self._prescribed_field_keys = tuple(sorted(prescribed_field_keys)) self.zero_prescribed() self.initialize() @@ -476,6 +476,13 @@ def fields(self) -> Fields: return self._fields + @property + def prescribed_field_keys(self) -> tuple: + """Keys of prescribed fields. + """ + + return self._prescribed_field_keys + @cached_property def poisson_solver(self) -> PoissonSolver: """Solver for the Poisson equation. @@ -487,7 +494,7 @@ def zero_prescribed(self): """Zero prescribed fields. """ - self.fields.zero(*self._prescribed_field_keys) + self.fields.zero(*self.prescribed_field_keys) @abstractmethod def initialize(self, zeta=None): @@ -501,7 +508,7 @@ def initialize(self, zeta=None): field. """ - self.fields.clear(keep_keys=self._prescribed_field_keys) + self.fields.clear(keep_keys=self.prescribed_field_keys) self._n = 0 @abstractmethod @@ -756,10 +763,17 @@ def read(cls, h, path="solver"): return model - def new(self): + def new(self, *, copy_prescribed=False): """Return a new :class:`.Solver` with the same configuration as this :class:`.Solver`. + Parameters + ---------- + + copy_prescribed : bool + Whether to copy values of prescribed fields to the new + :class:`.Solver`. + Returns ------- @@ -768,6 +782,9 @@ def new(self): """ model = type(self)(self.parameters, idtype=self.grid.idtype, fdtype=self.grid.fdtype) + if copy_prescribed: + for key in self.prescribed_field_keys: + model.fields[key] = self.fields[key] model.poisson_solver = self.poisson_solver return model @@ -922,8 +939,8 @@ def step(self): def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000): return super().steady_state_solve(*args, update=update, tol=tol, max_it=max_it, _min_n=1) - def new(self): - model = super().new() + def new(self, copy_prescribed=False): + model = super().new(copy_prescribed=copy_prescribed) model.modified_helmholtz_solver = self.modified_helmholtz_solver return model diff --git a/bt_ocean/network.py b/bt_ocean/network.py index 48f9286..75b9db0 100644 --- a/bt_ocean/network.py +++ b/bt_ocean/network.py @@ -177,7 +177,7 @@ def compute_step_outputs(dynamics, _): return dynamics, dynamics.fields["zeta"] * self.__output_weight def compute_outputs(zeta): - dynamics = self.__dynamics.new() + dynamics = self.__dynamics.new(copy_prescribed=True) dynamics.initialize(zeta * self.__input_weight) _, outputs = jax.lax.scan(compute_step_outputs, dynamics, None, length=self.__n_output) return outputs From e3ee9d6136cd17a542c6af438ea032fceaf16eae Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 3 Dec 2024 12:41:53 +0000 Subject: [PATCH 2/2] Transfer of elliptic solvers is no longer needed --- bt_ocean/model.py | 21 ++------------------- 1 file changed, 2 insertions(+), 19 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 27ffb24..c6ebe5a 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -785,7 +785,6 @@ def new(self, *, copy_prescribed=False): if copy_prescribed: for key in self.prescribed_field_keys: model.fields[key] = self.fields[key] - model.poisson_solver = self.poisson_solver return model def update(self, model): @@ -811,18 +810,17 @@ def flatten(self): """ return ((dict(self.fields), self.n), - (self.parameters, self.grid.idtype, self.grid.fdtype, self.poisson_solver)) + (self.parameters, self.grid.idtype, self.grid.fdtype)) @classmethod def unflatten(cls, aux_data, children): """Unpack a JAX flattened representation. """ - parameters, idtype, fdtype, poisson_solver = aux_data + parameters, idtype, fdtype = aux_data fields, n = children model = cls(parameters, idtype=idtype, fdtype=fdtype) - model.poisson_solver = poisson_solver model.fields.update({key: value for key, value in fields.items() if type(value) is not object}) if type(n) is not object: model.n = n @@ -938,18 +936,3 @@ def step(self): def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000): return super().steady_state_solve(*args, update=update, tol=tol, max_it=max_it, _min_n=1) - - def new(self, copy_prescribed=False): - model = super().new(copy_prescribed=copy_prescribed) - model.modified_helmholtz_solver = self.modified_helmholtz_solver - return model - - def flatten(self): - children, aux_data = super().flatten() - return children, aux_data + (self.modified_helmholtz_solver,) - - @classmethod - def unflatten(cls, aux_data, children): - model = super().unflatten(aux_data[:-1], children) - model.modified_helmholtz_solver = aux_data[-1] - return model