From 8eaa9f5a73b77fe3be9ce9b70a11f999d03ce960 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 2 Dec 2024 18:37:35 +0000 Subject: [PATCH] More precision fixes --- bt_ocean/model.py | 14 +++++++++----- 1 file changed, 9 insertions(+), 5 deletions(-) diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 1366f24..2ea8ff8 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -767,7 +767,7 @@ def new(self): The new :class:`.Solver`. """ - model = type(self)(self.parameters) + model = type(self)(self.parameters, idtype=self.grid.idtype, fdtype=self.grid.fdtype) model.poisson_solver = self.poisson_solver return model @@ -794,17 +794,17 @@ def flatten(self): """ return ((dict(self.fields), self.n), - (self.parameters, self.poisson_solver)) + (self.parameters, self.grid.idtype, self.grid.fdtype, self.poisson_solver)) @classmethod def unflatten(cls, aux_data, children): """Unpack a JAX flattened representation. """ - parameters, poisson_solver = aux_data + parameters, idtype, fdtype, poisson_solver = aux_data fields, n = children - model = cls(parameters) + model = cls(parameters, idtype=idtype, fdtype=fdtype) model.poisson_solver = poisson_solver model.fields.update({key: value for key, value in fields.items() if type(value) is not object}) if type(n) is not object: @@ -815,6 +815,8 @@ def unflatten(cls, aux_data, children): def get_config(self): return {"type": type(self).__name__, "parameters": dict(self.parameters), + "idtype": jnp.dtype(self.grid.idtype).name, + "fdtype": jnp.dtype(self.grid.fdtype).name, "fields": dict(self.fields), "n": self.n} @@ -822,7 +824,9 @@ def get_config(self): def from_config(cls, config): config = {key: keras.saving.deserialize_keras_object(value) for key, value in config.items()} cls = cls._registry[config["type"]] - model = cls(config["parameters"]) + idtype = jnp.dtype(config["idtype"]).type if "idtype" in config else None + fdtype = jnp.dtype(config["fdtype"]).type if "fdtype" in config else None + model = cls(config["parameters"], idtype=idtype, fdtype=fdtype) model.fields.update(config["fields"]) model.n = config["n"] return model