Skip to content

Commit

Permalink
Merge pull request #30 from jrmaddison/jrmaddison/single_precision_read
Browse files Browse the repository at this point in the history
Fix to `Solver.read` to allow single precision files to be loaded when using default double precision
  • Loading branch information
jrmaddison authored Dec 2, 2024
2 parents 1f147aa + a036ba1 commit 6667d00
Show file tree
Hide file tree
Showing 4 changed files with 78 additions and 13 deletions.
6 changes: 2 additions & 4 deletions bt_ocean/grid.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,11 +30,9 @@ class Grid:
N_y : Integral
Number of :math:`y`-dimension divisions.
idtype : type
Integer scalar data type. Defaults to `jax.numpy.int64` if 64-bit is
enabled, and `jax.numpy.int32` otherwise.
Integer scalar data type. Defaults to :func:`.default_idtype()`.
fdtype : type
Floating point scalar data type. Defaults to `jax.numpy.float64` if
64-bit is enabled, and `jax.numpy.float32` otherwise.
Floating point scalar data type. Defaults to :func:`.default_fdtype()`.
"""

def __init__(self, L_x, L_y, N_x, N_y, *, idtype=None, fdtype=None):
Expand Down
26 changes: 19 additions & 7 deletions bt_ocean/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from .fft import dst
from .grid import Grid
from .inversion import ModifiedHelmholtzSolver, PoissonSolver
from .precision import default_idtype, default_fdtype

__all__ = \
[
Expand Down Expand Up @@ -348,6 +349,10 @@ class Solver(ABC):
- `'r'` : Linear drag coefficient, :math:`r`.
- `dt` : Time step size.
idtype : type
Integer scalar data type. Defaults to :func:`.default_idtype()`.
fdtype : type
Floating point scalar data type. Defaults to :func:`.default_fdtype()`.
field_keys : Iterable
Keys for fields. The following keys are added by default
Expand All @@ -372,8 +377,13 @@ class Solver(ABC):
"nu": required,
"dt": required}

def __init__(self, parameters, *, field_keys=None, prescribed_field_keys=None):
def __init__(self, parameters, *, idtype=None, fdtype=None,
field_keys=None, prescribed_field_keys=None):
self._parameters = parameters = Parameters(parameters, defaults=self._defaults)
if idtype is None:
idtype = default_idtype()
if fdtype is None:
fdtype = default_fdtype()
if field_keys is None:
field_keys = set()
else:
Expand All @@ -384,7 +394,8 @@ def __init__(self, parameters, *, field_keys=None, prescribed_field_keys=None):

self._grid = grid = Grid(
parameters["L_x"], parameters["L_y"],
parameters["N_x"], parameters["N_y"])
parameters["N_x"], parameters["N_y"],
idtype=idtype, fdtype=fdtype)

self._fields = Fields(grid, field_keys)
self._prescribed_field_keys = set(prescribed_field_keys)
Expand Down Expand Up @@ -737,7 +748,9 @@ def read(cls, h, path="solver"):
del h

cls = cls._registry[g.attrs["type"]]
model = cls(Parameters.read(g, "parameters"))
idtype = jnp.dtype(g["fields"].attrs["idtype"]).type
fdtype = jnp.dtype(g["fields"].attrs["fdtype"]).type
model = cls(Parameters.read(g, "parameters"), idtype=idtype, fdtype=fdtype)
model.fields.update(Fields.read(g, "fields", grid=model.grid))
model.n = g.attrs["n"]

Expand Down Expand Up @@ -829,13 +842,12 @@ class CNAB2Solver(Solver):
Parameters
----------
parameters : :class:`.Parameters`
Model parameters. See :class:`.Solver`.
See :class:`.Solver`.
"""

def __init__(self, parameters):
def __init__(self, parameters, *, idtype=None, fdtype=None):
super().__init__(
parameters,
parameters, idtype=idtype, fdtype=fdtype,
field_keys={"F_1"})

@cached_property
Expand Down
23 changes: 21 additions & 2 deletions bt_ocean/precision.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,32 @@

__all__ = \
[
"x64_disabled",
"x64_enabled",

"default_idtype",
"default_fdtype"
]


@contextmanager
def x64_disabled():
"""Context manager for temporarily disabling the `'jax_enable_x64'` JAX
configuration option, and for temporarily setting the Keras default float
type to single precision.
"""

x64_enabled = jax.config.x64_enabled
floatx = keras.backend.floatx()
try:
jax.config.update("jax_enable_x64", False)
keras.backend.set_floatx("float32")
yield
finally:
jax.config.update("jax_enable_x64", x64_enabled)
keras.backend.set_floatx(floatx)


@contextmanager
def x64_enabled():
"""Context manager for temporarily enabling the `'jax_enable_x64'` JAX
Expand All @@ -24,10 +43,10 @@ def x64_enabled():
"""

x64_enabled = jax.config.x64_enabled
jax.config.update("jax_enable_x64", True)
floatx = keras.backend.floatx()
keras.backend.set_floatx("float64")
try:
jax.config.update("jax_enable_x64", True)
keras.backend.set_floatx("float64")
yield
finally:
jax.config.update("jax_enable_x64", x64_enabled)
Expand Down
36 changes: 36 additions & 0 deletions tests/test_io.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from bt_ocean.grid import Grid
from bt_ocean.model import CNAB2Solver, Fields, Parameters, Solver
from bt_ocean.parameters import parameters, Q
from bt_ocean.precision import default_fdtype, x64_disabled

import jax.numpy as jnp
import pytest
import zarr

from .test_base import test_precision # noqa: F401
Expand Down Expand Up @@ -90,3 +92,37 @@ def test_solver_roundtrip(tmp_path):
assert set(input_model.fields) == set(model.fields)
for key, value in model.fields.items():
assert (input_model.fields[key] == value).all()


def test_solver_roundtrip_precision_change(tmp_path):
if default_fdtype() != jnp.float64:
pytest.skip("Double precision only")

with x64_disabled():
model = CNAB2Solver(model_parameters())
model.fields["Q"] = Q(model.grid)
model.steps(5)

filename = tmp_path / "tmp.zarr"
with zarr.open(filename, "w") as h:
model.write(h)
with zarr.open(filename, "r") as h:
input_model = Solver.read(h)

assert type(input_model) is type(model)
assert input_model.n == model.n

assert input_model.grid.L_x == model.grid.L_x
assert input_model.grid.L_y == model.grid.L_y
assert input_model.grid.N_x == model.grid.N_x
assert input_model.grid.N_y == model.grid.N_y
assert input_model.grid.idtype == model.grid.idtype
assert input_model.grid.fdtype == model.grid.fdtype

assert set(input_model.parameters) == set(model.parameters)
for key, value in model.parameters.items():
assert input_model.parameters[key] == value

assert set(input_model.fields) == set(model.fields)
for key, value in model.fields.items():
assert (input_model.fields[key] == value).all()

0 comments on commit 6667d00

Please sign in to comment.