Skip to content

Commit

Permalink
Merge pull request #34 from jrmaddison/jrmaddison/single_precision
Browse files Browse the repository at this point in the history
More precision fixes
  • Loading branch information
jrmaddison authored Dec 2, 2024
2 parents 88c0ba4 + 8eaa9f5 commit 9c5e292
Showing 1 changed file with 9 additions and 5 deletions.
14 changes: 9 additions & 5 deletions bt_ocean/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -767,7 +767,7 @@ def new(self):
The new :class:`.Solver`.
"""

model = type(self)(self.parameters)
model = type(self)(self.parameters, idtype=self.grid.idtype, fdtype=self.grid.fdtype)
model.poisson_solver = self.poisson_solver
return model

Expand All @@ -794,17 +794,17 @@ def flatten(self):
"""

return ((dict(self.fields), self.n),
(self.parameters, self.poisson_solver))
(self.parameters, self.grid.idtype, self.grid.fdtype, self.poisson_solver))

@classmethod
def unflatten(cls, aux_data, children):
"""Unpack a JAX flattened representation.
"""

parameters, poisson_solver = aux_data
parameters, idtype, fdtype, poisson_solver = aux_data
fields, n = children

model = cls(parameters)
model = cls(parameters, idtype=idtype, fdtype=fdtype)
model.poisson_solver = poisson_solver
model.fields.update({key: value for key, value in fields.items() if type(value) is not object})
if type(n) is not object:
Expand All @@ -815,14 +815,18 @@ def unflatten(cls, aux_data, children):
def get_config(self):
return {"type": type(self).__name__,
"parameters": dict(self.parameters),
"idtype": jnp.dtype(self.grid.idtype).name,
"fdtype": jnp.dtype(self.grid.fdtype).name,
"fields": dict(self.fields),
"n": self.n}

@classmethod
def from_config(cls, config):
config = {key: keras.saving.deserialize_keras_object(value) for key, value in config.items()}
cls = cls._registry[config["type"]]
model = cls(config["parameters"])
idtype = jnp.dtype(config["idtype"]).type if "idtype" in config else None
fdtype = jnp.dtype(config["fdtype"]).type if "fdtype" in config else None
model = cls(config["parameters"], idtype=idtype, fdtype=fdtype)
model.fields.update(config["fields"])
model.n = config["n"]
return model
Expand Down

0 comments on commit 9c5e292

Please sign in to comment.