Skip to content

Commit

Permalink
Merge pull request #36 from jrmaddison/jrmaddison/initialize_prescribed
Browse files Browse the repository at this point in the history
Initialize prescribed fields in `Dynamics.call`
  • Loading branch information
jrmaddison authored Dec 3, 2024
2 parents f44e606 + e3ee9d6 commit fed075f
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 24 deletions.
46 changes: 23 additions & 23 deletions bt_ocean/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,7 +398,7 @@ def __init__(self, parameters, *, idtype=None, fdtype=None,
idtype=idtype, fdtype=fdtype)

self._fields = Fields(grid, field_keys)
self._prescribed_field_keys = set(prescribed_field_keys)
self._prescribed_field_keys = tuple(sorted(prescribed_field_keys))

self.zero_prescribed()
self.initialize()
Expand Down Expand Up @@ -476,6 +476,13 @@ def fields(self) -> Fields:

return self._fields

@property
def prescribed_field_keys(self) -> tuple:
"""Keys of prescribed fields.
"""

return self._prescribed_field_keys

@cached_property
def poisson_solver(self) -> PoissonSolver:
"""Solver for the Poisson equation.
Expand All @@ -487,7 +494,7 @@ def zero_prescribed(self):
"""Zero prescribed fields.
"""

self.fields.zero(*self._prescribed_field_keys)
self.fields.zero(*self.prescribed_field_keys)

@abstractmethod
def initialize(self, zeta=None):
Expand All @@ -501,7 +508,7 @@ def initialize(self, zeta=None):
field.
"""

self.fields.clear(keep_keys=self._prescribed_field_keys)
self.fields.clear(keep_keys=self.prescribed_field_keys)
self._n = 0

@abstractmethod
Expand Down Expand Up @@ -756,10 +763,17 @@ def read(cls, h, path="solver"):

return model

def new(self):
def new(self, *, copy_prescribed=False):
"""Return a new :class:`.Solver` with the same configuration as this
:class:`.Solver`.
Parameters
----------
copy_prescribed : bool
Whether to copy values of prescribed fields to the new
:class:`.Solver`.
Returns
-------
Expand All @@ -768,7 +782,9 @@ def new(self):
"""

model = type(self)(self.parameters, idtype=self.grid.idtype, fdtype=self.grid.fdtype)
model.poisson_solver = self.poisson_solver
if copy_prescribed:
for key in self.prescribed_field_keys:
model.fields[key] = self.fields[key]
return model

def update(self, model):
Expand All @@ -794,18 +810,17 @@ def flatten(self):
"""

return ((dict(self.fields), self.n),
(self.parameters, self.grid.idtype, self.grid.fdtype, self.poisson_solver))
(self.parameters, self.grid.idtype, self.grid.fdtype))

@classmethod
def unflatten(cls, aux_data, children):
"""Unpack a JAX flattened representation.
"""

parameters, idtype, fdtype, poisson_solver = aux_data
parameters, idtype, fdtype = aux_data
fields, n = children

model = cls(parameters, idtype=idtype, fdtype=fdtype)
model.poisson_solver = poisson_solver
model.fields.update({key: value for key, value in fields.items() if type(value) is not object})
if type(n) is not object:
model.n = n
Expand Down Expand Up @@ -921,18 +936,3 @@ def step(self):

def steady_state_solve(self, *args, update=lambda model, *args: None, tol, max_it=10000):
return super().steady_state_solve(*args, update=update, tol=tol, max_it=max_it, _min_n=1)

def new(self):
model = super().new()
model.modified_helmholtz_solver = self.modified_helmholtz_solver
return model

def flatten(self):
children, aux_data = super().flatten()
return children, aux_data + (self.modified_helmholtz_solver,)

@classmethod
def unflatten(cls, aux_data, children):
model = super().unflatten(aux_data[:-1], children)
model.modified_helmholtz_solver = aux_data[-1]
return model
2 changes: 1 addition & 1 deletion bt_ocean/network.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,7 +177,7 @@ def compute_step_outputs(dynamics, _):
return dynamics, dynamics.fields["zeta"] * self.__output_weight

def compute_outputs(zeta):
dynamics = self.__dynamics.new()
dynamics = self.__dynamics.new(copy_prescribed=True)
dynamics.initialize(zeta * self.__input_weight)
_, outputs = jax.lax.scan(compute_step_outputs, dynamics, None, length=self.__n_output)
return outputs
Expand Down

0 comments on commit fed075f

Please sign in to comment.