From 128c7d4284376d7490290237c8d291e684885437 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 24 Oct 2023 11:27:11 +0100 Subject: [PATCH 1/3] Set real assumption for SymbolicFloat objects --- tlm_adjoint/overloaded_float.py | 9 +++++++-- 1 file changed, 7 insertions(+), 2 deletions(-) diff --git a/tlm_adjoint/overloaded_float.py b/tlm_adjoint/overloaded_float.py index b8e02fc2..f81427b2 100644 --- a/tlm_adjoint/overloaded_float.py +++ b/tlm_adjoint/overloaded_float.py @@ -376,8 +376,13 @@ def __init__(self, value=0.0, *, name=None, space_type="primal", else: self.assign(value, annotate=annotate, tlm=tlm) - def __new__(cls, *args, **kwargs): - return super().__new__(cls, new_symbol_name()) + def __new__(cls, *args, dtype=None, **kwargs): + if dtype is None: + dtype = _default_dtype + if issubclass(dtype, (float, np.floating)): + return super().__new__(cls, new_symbol_name(), real=True) + else: + return super().__new__(cls, new_symbol_name(), complex=True) def new(self, value=0.0, *, name=None, From 2d44fcb65eb1a02e230bc32196cda3609d6f6f80 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 24 Oct 2023 11:57:49 +0100 Subject: [PATCH 2/3] Further Float overloading, further JAX overloading tests --- tests/base/test_float.py | 16 +++++++--- tests/base/test_jax.py | 19 ++++++++---- tlm_adjoint/overloaded_float.py | 53 +++++++++++++++++++++++++++++++++ 3 files changed, 78 insertions(+), 10 deletions(-) diff --git a/tests/base/test_float.py b/tests/base/test_float.py index 97f98aff..0b9318cd 100644 --- a/tests/base/test_float.py +++ b/tests/base/test_float.py @@ -104,10 +104,16 @@ def forward(y): np.arccosh, np.arctanh, np.exp, + np.exp2, np.expm1, np.log, + np.log2, np.log10, - np.sqrt]) + np.log1p, + np.sqrt, + np.square, + np.cbrt, + np.reciprocal]) @seed_test def test_float_unary_overloading(setup_test, # noqa: F811 op): @@ -123,7 +129,7 @@ def forward(y): return (c - 1.0) ** 4 - if op is np.arccosh: + if op in {np.arccosh, np.reciprocal}: y = Float(1.1) else: y = Float(0.1) @@ -165,11 +171,13 @@ def forward(y): operator.mul, operator.truediv, operator.pow, - np.arctan2]) + np.arctan2, + np.hypot]) @seed_test def test_float_binary_overloading(setup_test, # noqa: F811 dtype, op): - if op is np.arctan2 and issubclass(dtype, (complex, np.complexfloating)): + if op in {np.arctan2, np.hypot} \ + and issubclass(dtype, (complex, np.complexfloating)): pytest.skip() set_default_float_dtype(dtype) diff --git a/tests/base/test_jax.py b/tests/base/test_jax.py index d3d771a7..ead987a1 100644 --- a/tests/base/test_jax.py +++ b/tests/base/test_jax.py @@ -98,10 +98,16 @@ def forward(y): np.arccosh, np.arctanh, np.exp, + np.exp2, np.expm1, np.log, + np.log2, np.log10, - np.sqrt]) + np.log1p, + np.sqrt, + np.square, + np.cbrt, + np.reciprocal]) @seed_test def test_jax_unary_overloading(setup_test, jax_tlm_config, # noqa: F811 op): @@ -119,7 +125,7 @@ def forward(y): return (c - 1.0) ** 4 - if op is np.arccosh: + if op in {np.arccosh, np.reciprocal}: y = np.array([1.1, 1.2], dtype=np.double) else: y = np.array([0.1, 0.2], dtype=np.double) @@ -160,16 +166,17 @@ def forward(y): operator.mul, operator.truediv, operator.pow, - np.arctan2]) + np.arctan2, + np.hypot]) @seed_test def test_jax_binary_overloading(setup_test, jax_tlm_config, # noqa: F811 dtype, op): + if op in {np.arctan2, np.hypot} \ + and issubclass(dtype, (complex, np.complexfloating)): + pytest.skip() set_default_float_dtype(dtype) set_default_jax_dtype(dtype) - if op is np.arctan2 and issubclass(dtype, (complex, np.complexfloating)): - pytest.skip() - def forward(y): x = y * y x = op(x, y) diff --git a/tlm_adjoint/overloaded_float.py b/tlm_adjoint/overloaded_float.py index f81427b2..abb1d158 100644 --- a/tlm_adjoint/overloaded_float.py +++ b/tlm_adjoint/overloaded_float.py @@ -580,6 +580,24 @@ def fdiff(self, argindex=1): return sp.exp(self.args[0]) +@register_function(np.log1p, "numpy.log1p") +class _tlm_adjoint__log1p(sp.Function): # noqa: N801 + def fdiff(self, argindex=1): + if argindex == 1: + return sp.Integer(1) / (sp.Integer(1) + self.args[0]) + + +@register_function(np.hypot, "numpy.hypot") +class _tlm_adjoint__hypot(sp.Function): # noqa: N801 + def fdiff(self, argindex=1): + if argindex == 1: + return self.args[0] / _tlm_adjoint__hypot(self.args[0], + self.args[1]) + elif argindex == 2: + return self.args[1] / _tlm_adjoint__hypot(self.args[0], + self.args[1]) + + class _tlm_adjoint__OverloadedFloat(np.lib.mixins.NDArrayOperatorsMixin, # noqa: E501,N801 SymbolicFloat): """A subclass of :class:`.SymbolicFloat` with operator overloading. @@ -702,6 +720,11 @@ def arctan(x): def arctan2(x1, x2): return sp.atan2(x1, x2) + @staticmethod + @register_operation(np.hypot) + def hypot(x1, x2): + return _tlm_adjoint__hypot(x1, x2) + @staticmethod @register_operation(np.sinh) def sinh(x): @@ -737,6 +760,11 @@ def arctanh(x): def exp(x): return sp.exp(x) + @staticmethod + @register_operation(np.exp2) + def exp2(x): + return 2 ** x + @staticmethod @register_operation(np.expm1) def expm1(x): @@ -747,16 +775,41 @@ def expm1(x): def log(x): return sp.log(x) + @staticmethod + @register_operation(np.log2) + def log2(x): + return sp.log(x, 2) + @staticmethod @register_operation(np.log10) def log10(x): return sp.log(x, 10) + @staticmethod + @register_operation(np.log1p) + def log1p(x): + return _tlm_adjoint__log1p(x) + @staticmethod @register_operation(np.sqrt) def sqrt(x): return sp.sqrt(x) + @staticmethod + @register_operation(np.square) + def square(x): + return x ** 2 + + @staticmethod + @register_operation(np.cbrt) + def cbrt(x): + return x ** sp.Rational(1, 3) + + @staticmethod + @register_operation(np.reciprocal) + def reciprocal(x): + return sp.Integer(1) / x + # Required by Sphinx class OverloadedFloat(_tlm_adjoint__OverloadedFloat): From 550d08148543e5fccb8e3e8d5bfd7d529de3baf9 Mon Sep 17 00:00:00 2001 From: "James R. Maddison" Date: Tue, 24 Oct 2023 12:35:16 +0100 Subject: [PATCH 3/3] Add dtype checks in FloatSpace and VectorSpace --- tlm_adjoint/jax.py | 4 ++++ tlm_adjoint/overloaded_float.py | 4 ++++ 2 files changed, 8 insertions(+) diff --git a/tlm_adjoint/jax.py b/tlm_adjoint/jax.py index 2600d302..8282c4b4 100644 --- a/tlm_adjoint/jax.py +++ b/tlm_adjoint/jax.py @@ -98,6 +98,10 @@ def __init__(self, n, *, dtype=None, comm=None): if comm is None: comm = DEFAULT_COMM + if not issubclass(dtype, (float, np.floating, + complex, np.complexfloating)): + raise TypeError("Invalid dtype") + self._n = n self._dtype = dtype self._comm = comm_dup_cached(comm) diff --git a/tlm_adjoint/overloaded_float.py b/tlm_adjoint/overloaded_float.py index abb1d158..22a8a9af 100644 --- a/tlm_adjoint/overloaded_float.py +++ b/tlm_adjoint/overloaded_float.py @@ -132,6 +132,10 @@ def __init__(self, float_cls=None, *, dtype=None, comm=None): if comm is None: comm = DEFAULT_COMM + if not issubclass(dtype, (float, np.floating, + complex, np.complexfloating)): + raise TypeError("Invalid dtype") + self._comm = comm_dup_cached(comm) self._dtype = dtype self._float_cls = float_cls