diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 0365637..c6ebe5a 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -398,7 +398,7 @@ def __init__(self, parameters, *, idtype=None, fdtype=None, idtype=idtype, fdtype=fdtype) self._fields = Fields(grid, field_keys) - self._prescribed_field_keys = set(prescribed_field_keys) + self._prescribed_field_keys = tuple(sorted(prescribed_field_keys)) self.zero_prescribed() self.initialize() @@ -476,6 +476,13 @@ def fields(self) -> Fields: return self._fields + @property + def prescribed_field_keys(self) -> tuple: + """Keys of prescribed fields. + """ + + return self._prescribed_field_keys + @cached_property def poisson_solver(self) -> PoissonSolver: """Solver for the Poisson equation. @@ -487,7 +494,7 @@ def zero_prescribed(self): """Zero prescribed fields. """ - self.fields.zero(*self._prescribed_field_keys) + self.fields.zero(*self.prescribed_field_keys) @abstractmethod def initialize(self, zeta=None): @@ -501,7 +508,7 @@ def initialize(self, zeta=None): field. """ - self.fields.clear(keep_keys=self._prescribed_field_keys) + self.fields.clear(keep_keys=self.prescribed_field_keys) self._n = 0 @abstractmethod @@ -756,10 +763,17 @@ def read(cls, h, path="solver"): return model - def new(self): + def new(self, *, copy_prescribed=False): """Return a new :class:`.Solver` with the same configuration as this :class:`.Solver`. + Parameters + ---------- + + copy_prescribed : bool + Whether to copy values of prescribed fields to the new + :class:`.Solver`. + Returns ------- @@ -768,7 +782,9 @@ def new(self): """ model = type(self)(self.parameters, idtype=self.grid.idtype, fdtype=self.grid.fdtype) - model.poisson_solver = self.poisson_solver + if copy_prescribed: + for key in self.prescribed_field_keys: + model.fields[key] = self.fields[key] return model def update(self, model): @@ -794,18 +810,17 @@ def flatten(self): """ return ((dict(self.fields), self.n), - (self.parameters, self.grid.idtype, self.grid.fdtype, self.poisson_solver)) + (self.parameters, self.grid.idtype, self.grid.fdtype)) @classmethod def unflatten(cls, aux_data, children): """Unpack a JAX flattened representation. """ - parameters, idtype, fdtype, poisson_solver = aux_data + parameters, idtype, fdtype = aux_data fields, n = children model = cls(parameters, idtype=idtype, fdtype=fdtype) - model.poisson_solver = poisson_solver model.fields.update({key: value for key, value in fields.items() if type(value) is not object}) if type(n) is not object: model.n = n @@ -921,18 +936,3 @@ def step(self): def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000): return super().steady_state_solve(*args, update=update, tol=tol, max_it=max_it, _min_n=1) - - def new(self): - model = super().new() - model.modified_helmholtz_solver = self.modified_helmholtz_solver - return model - - def flatten(self): - children, aux_data = super().flatten() - return children, aux_data + (self.modified_helmholtz_solver,) - - @classmethod - def unflatten(cls, aux_data, children): - model = super().unflatten(aux_data[:-1], children) - model.modified_helmholtz_solver = aux_data[-1] - return model diff --git a/bt_ocean/network.py b/bt_ocean/network.py index 48f9286..75b9db0 100644 --- a/bt_ocean/network.py +++ b/bt_ocean/network.py @@ -177,7 +177,7 @@ def compute_step_outputs(dynamics, _): return dynamics, dynamics.fields["zeta"] * self.__output_weight def compute_outputs(zeta): - dynamics = self.__dynamics.new() + dynamics = self.__dynamics.new(copy_prescribed=True) dynamics.initialize(zeta * self.__input_weight) _, outputs = jax.lax.scan(compute_step_outputs, dynamics, None, length=self.__n_output) return outputs