From f925e91375e20e4199d26f9a95c29b8d2bf5adf0 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Mon, 26 Aug 2024 18:53:19 +0100 Subject: [PATCH] Do not set the jax_enable_x64 JAX configuration option on import --- tests/test_base.py | 8 ++++++++ tlm_adjoint/jax.py | 37 ++++++++++++++++++++++--------------- 2 files changed, 30 insertions(+), 15 deletions(-) diff --git a/tests/test_base.py b/tests/test_base.py index 96f56674..e0f3b81f 100644 --- a/tests/test_base.py +++ b/tests/test_base.py @@ -4,6 +4,10 @@ import functools import hashlib import inspect +try: + import jax +except ModuleNotFoundError: + jax = None import json import numpy as np from operator import itemgetter @@ -20,6 +24,10 @@ ] +if jax is not None: + jax.config.update("jax_enable_x64", True) + + def seed_test(fn): @patch_function(fn) def wrapped_fn(orig, orig_args, *args, **kwargs): diff --git a/tlm_adjoint/jax.py b/tlm_adjoint/jax.py index 27c3ea18..907f93a6 100644 --- a/tlm_adjoint/jax.py +++ b/tlm_adjoint/jax.py @@ -23,7 +23,6 @@ try: import jax - jax.config.update("jax_enable_x64", True) except ModuleNotFoundError: jax = None @@ -42,28 +41,36 @@ "to_jax" ] -try: - import petsc4py.PETSc as PETSc - _default_dtype = PETSc.ScalarType -except ModuleNotFoundError: - _default_dtype = np.double -_default_dtype = np.dtype(_default_dtype).type -if not issubclass(_default_dtype, (np.floating, np.complexfloating)): - raise ImportError("Invalid default dtype") +_default_dtype = None + + +def default_dtype(): + if _default_dtype is None: + return jax.numpy.array(0.0).dtype.type + else: + return _default_dtype def set_default_jax_dtype(dtype): """Set the default data type used by :class:`.Vector` objects. - :arg dtype: The default data type. + Parameters + ---------- + + dtype : type + The default data type. If `None` then the default JAX floating point + scalar data type is used. """ global _default_dtype - dtype = np.dtype(dtype).type - if not issubclass(dtype, (np.floating, np.complexfloating)): - raise TypeError("Invalid dtype") - _default_dtype = dtype + if dtype is None: + _default_dtype = None + else: + dtype = np.dtype(dtype).type + if not issubclass(dtype, (np.floating, np.complexfloating)): + raise TypeError("Invalid dtype") + _default_dtype = dtype class VectorSpaceInterface(SpaceInterface): @@ -98,7 +105,7 @@ class VectorSpace: def __init__(self, n, *, dtype=None, comm=None): if dtype is None: - dtype = _default_dtype + dtype = default_dtype() if comm is None: comm = DEFAULT_COMM comm = comm_dup_cached(comm)