diff --git a/bt_ocean/grid.py b/bt_ocean/grid.py index e1c5f56..23baf2b 100644 --- a/bt_ocean/grid.py +++ b/bt_ocean/grid.py @@ -30,11 +30,9 @@ class Grid: N_y : Integral Number of :math:`y`-dimension divisions. idtype : type - Integer scalar data type. Defaults to `jax.numpy.int64` if 64-bit is - enabled, and `jax.numpy.int32` otherwise. + Integer scalar data type. Defaults to :func:`.default_idtype()`. fdtype : type - Floating point scalar data type. Defaults to `jax.numpy.float64` if - 64-bit is enabled, and `jax.numpy.float32` otherwise. + Floating point scalar data type. Defaults to :func:`.default_fdtype()`. """ def __init__(self, L_x, L_y, N_x, N_y, *, idtype=None, fdtype=None): diff --git a/bt_ocean/model.py b/bt_ocean/model.py index 59bad0d..df600e9 100644 --- a/bt_ocean/model.py +++ b/bt_ocean/model.py @@ -14,6 +14,7 @@ from .fft import dst from .grid import Grid from .inversion import ModifiedHelmholtzSolver, PoissonSolver +from .precision import default_idtype, default_fdtype __all__ = \ [ @@ -348,6 +349,10 @@ class Solver(ABC): - `'r'` : Linear drag coefficient, :math:`r`. - `dt` : Time step size. + idtype : type + Integer scalar data type. Defaults to :func:`.default_idtype()`. + fdtype : type + Floating point scalar data type. Defaults to :func:`.default_fdtype()`. field_keys : Iterable Keys for fields. The following keys are added by default @@ -372,8 +377,13 @@ class Solver(ABC): "nu": required, "dt": required} - def __init__(self, parameters, *, field_keys=None, prescribed_field_keys=None): + def __init__(self, parameters, *, idtype=None, fdtype=None, + field_keys=None, prescribed_field_keys=None): self._parameters = parameters = Parameters(parameters, defaults=self._defaults) + if idtype is None: + idtype = default_idtype() + if fdtype is None: + fdtype = default_fdtype() if field_keys is None: field_keys = set() else: @@ -384,7 +394,8 @@ def __init__(self, parameters, *, field_keys=None, prescribed_field_keys=None): self._grid = grid = Grid( parameters["L_x"], parameters["L_y"], - parameters["N_x"], parameters["N_y"]) + parameters["N_x"], parameters["N_y"], + idtype=idtype, fdtype=fdtype) self._fields = Fields(grid, field_keys) self._prescribed_field_keys = set(prescribed_field_keys) @@ -737,7 +748,9 @@ def read(cls, h, path="solver"): del h cls = cls._registry[g.attrs["type"]] - model = cls(Parameters.read(g, "parameters")) + idtype = jnp.dtype(g["fields"].attrs["idtype"]).type + fdtype = jnp.dtype(g["fields"].attrs["fdtype"]).type + model = cls(Parameters.read(g, "parameters"), idtype=idtype, fdtype=fdtype) model.fields.update(Fields.read(g, "fields", grid=model.grid)) model.n = g.attrs["n"] @@ -829,13 +842,12 @@ class CNAB2Solver(Solver): Parameters ---------- - parameters : :class:`.Parameters` - Model parameters. See :class:`.Solver`. + See :class:`.Solver`. """ - def __init__(self, parameters): + def __init__(self, parameters, *, idtype=None, fdtype=None): super().__init__( - parameters, + parameters, idtype=idtype, fdtype=fdtype, field_keys={"F_1"}) @cached_property diff --git a/bt_ocean/precision.py b/bt_ocean/precision.py index 5daa2ea..ab89350 100644 --- a/bt_ocean/precision.py +++ b/bt_ocean/precision.py @@ -9,6 +9,7 @@ __all__ = \ [ + "x64_disabled", "x64_enabled", "default_idtype", @@ -16,6 +17,24 @@ ] +@contextmanager +def x64_disabled(): + """Context manager for temporarily disabling the `'jax_enable_x64'` JAX + configuration option, and for temporarily setting the Keras default float + type to single precision. + """ + + x64_enabled = jax.config.x64_enabled + floatx = keras.backend.floatx() + try: + jax.config.update("jax_enable_x64", False) + keras.backend.set_floatx("float32") + yield + finally: + jax.config.update("jax_enable_x64", x64_enabled) + keras.backend.set_floatx(floatx) + + @contextmanager def x64_enabled(): """Context manager for temporarily enabling the `'jax_enable_x64'` JAX @@ -24,10 +43,10 @@ def x64_enabled(): """ x64_enabled = jax.config.x64_enabled - jax.config.update("jax_enable_x64", True) floatx = keras.backend.floatx() - keras.backend.set_floatx("float64") try: + jax.config.update("jax_enable_x64", True) + keras.backend.set_floatx("float64") yield finally: jax.config.update("jax_enable_x64", x64_enabled) diff --git a/tests/test_io.py b/tests/test_io.py index 29d8b5f..0c0a42b 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,8 +1,10 @@ from bt_ocean.grid import Grid from bt_ocean.model import CNAB2Solver, Fields, Parameters, Solver from bt_ocean.parameters import parameters, Q +from bt_ocean.precision import default_fdtype, x64_disabled import jax.numpy as jnp +import pytest import zarr from .test_base import test_precision # noqa: F401 @@ -90,3 +92,37 @@ def test_solver_roundtrip(tmp_path): assert set(input_model.fields) == set(model.fields) for key, value in model.fields.items(): assert (input_model.fields[key] == value).all() + + +def test_solver_roundtrip_precision_change(tmp_path): + if default_fdtype() != jnp.float64: + pytest.skip("Double precision only") + + with x64_disabled(): + model = CNAB2Solver(model_parameters()) + model.fields["Q"] = Q(model.grid) + model.steps(5) + + filename = tmp_path / "tmp.zarr" + with zarr.open(filename, "w") as h: + model.write(h) + with zarr.open(filename, "r") as h: + input_model = Solver.read(h) + + assert type(input_model) is type(model) + assert input_model.n == model.n + + assert input_model.grid.L_x == model.grid.L_x + assert input_model.grid.L_y == model.grid.L_y + assert input_model.grid.N_x == model.grid.N_x + assert input_model.grid.N_y == model.grid.N_y + assert input_model.grid.idtype == model.grid.idtype + assert input_model.grid.fdtype == model.grid.fdtype + + assert set(input_model.parameters) == set(model.parameters) + for key, value in model.parameters.items(): + assert input_model.parameters[key] == value + + assert set(input_model.fields) == set(model.fields) + for key, value in model.fields.items(): + assert (input_model.fields[key] == value).all()