Skip to content

Commit

Permalink
Merge pull request #596 from tlm-adjoint/jrmaddison/jax_dtype
Browse files Browse the repository at this point in the history
Do not set the jax_enable_x64 JAX configuration option on import
  • Loading branch information
jrmaddison authored Aug 26, 2024
2 parents 63c5dad + f925e91 commit c372022
Show file tree
Hide file tree
Showing 2 changed files with 30 additions and 15 deletions.
8 changes: 8 additions & 0 deletions tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@
import functools
import hashlib
import inspect
try:
import jax
except ModuleNotFoundError:
jax = None
import json
import numpy as np
from operator import itemgetter
Expand All @@ -20,6 +24,10 @@
]


if jax is not None:
jax.config.update("jax_enable_x64", True)


def seed_test(fn):
@patch_function(fn)
def wrapped_fn(orig, orig_args, *args, **kwargs):
Expand Down
37 changes: 22 additions & 15 deletions tlm_adjoint/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

try:
import jax
jax.config.update("jax_enable_x64", True)
except ModuleNotFoundError:
jax = None

Expand All @@ -42,28 +41,36 @@
"to_jax"
]

try:
import petsc4py.PETSc as PETSc
_default_dtype = PETSc.ScalarType
except ModuleNotFoundError:
_default_dtype = np.double
_default_dtype = np.dtype(_default_dtype).type
if not issubclass(_default_dtype, (np.floating, np.complexfloating)):
raise ImportError("Invalid default dtype")
_default_dtype = None


def default_dtype():
if _default_dtype is None:
return jax.numpy.array(0.0).dtype.type
else:
return _default_dtype


def set_default_jax_dtype(dtype):
"""Set the default data type used by :class:`.Vector` objects.
:arg dtype: The default data type.
Parameters
----------
dtype : type
The default data type. If `None` then the default JAX floating point
scalar data type is used.
"""

global _default_dtype

dtype = np.dtype(dtype).type
if not issubclass(dtype, (np.floating, np.complexfloating)):
raise TypeError("Invalid dtype")
_default_dtype = dtype
if dtype is None:
_default_dtype = None
else:
dtype = np.dtype(dtype).type
if not issubclass(dtype, (np.floating, np.complexfloating)):
raise TypeError("Invalid dtype")
_default_dtype = dtype


class VectorSpaceInterface(SpaceInterface):
Expand Down Expand Up @@ -98,7 +105,7 @@ class VectorSpace:

def __init__(self, n, *, dtype=None, comm=None):
if dtype is None:
dtype = _default_dtype
dtype = default_dtype()
if comm is None:
comm = DEFAULT_COMM
comm = comm_dup_cached(comm)
Expand Down

0 comments on commit c372022

Please sign in to comment.