diff --git a/tests/base/test_float.py b/tests/base/test_float.py index 97f98aff..0b9318cd 100644 --- a/tests/base/test_float.py +++ b/tests/base/test_float.py @@ -104,10 +104,16 @@ def forward(y): np.arccosh, np.arctanh, np.exp, + np.exp2, np.expm1, np.log, + np.log2, np.log10, - np.sqrt]) + np.log1p, + np.sqrt, + np.square, + np.cbrt, + np.reciprocal]) @seed_test def test_float_unary_overloading(setup_test, # noqa: F811 op): @@ -123,7 +129,7 @@ def forward(y): return (c - 1.0) ** 4 - if op is np.arccosh: + if op in {np.arccosh, np.reciprocal}: y = Float(1.1) else: y = Float(0.1) @@ -165,11 +171,13 @@ def forward(y): operator.mul, operator.truediv, operator.pow, - np.arctan2]) + np.arctan2, + np.hypot]) @seed_test def test_float_binary_overloading(setup_test, # noqa: F811 dtype, op): - if op is np.arctan2 and issubclass(dtype, (complex, np.complexfloating)): + if op in {np.arctan2, np.hypot} \ + and issubclass(dtype, (complex, np.complexfloating)): pytest.skip() set_default_float_dtype(dtype) diff --git a/tests/base/test_jax.py b/tests/base/test_jax.py index d3d771a7..ead987a1 100644 --- a/tests/base/test_jax.py +++ b/tests/base/test_jax.py @@ -98,10 +98,16 @@ def forward(y): np.arccosh, np.arctanh, np.exp, + np.exp2, np.expm1, np.log, + np.log2, np.log10, - np.sqrt]) + np.log1p, + np.sqrt, + np.square, + np.cbrt, + np.reciprocal]) @seed_test def test_jax_unary_overloading(setup_test, jax_tlm_config, # noqa: F811 op): @@ -119,7 +125,7 @@ def forward(y): return (c - 1.0) ** 4 - if op is np.arccosh: + if op in {np.arccosh, np.reciprocal}: y = np.array([1.1, 1.2], dtype=np.double) else: y = np.array([0.1, 0.2], dtype=np.double) @@ -160,16 +166,17 @@ def forward(y): operator.mul, operator.truediv, operator.pow, - np.arctan2]) + np.arctan2, + np.hypot]) @seed_test def test_jax_binary_overloading(setup_test, jax_tlm_config, # noqa: F811 dtype, op): + if op in {np.arctan2, np.hypot} \ + and issubclass(dtype, (complex, np.complexfloating)): + pytest.skip() set_default_float_dtype(dtype) set_default_jax_dtype(dtype) - if op is np.arctan2 and issubclass(dtype, (complex, np.complexfloating)): - pytest.skip() - def forward(y): x = y * y x = op(x, y) diff --git a/tlm_adjoint/overloaded_float.py b/tlm_adjoint/overloaded_float.py index f81427b2..abb1d158 100644 --- a/tlm_adjoint/overloaded_float.py +++ b/tlm_adjoint/overloaded_float.py @@ -580,6 +580,24 @@ def fdiff(self, argindex=1): return sp.exp(self.args[0]) +@register_function(np.log1p, "numpy.log1p") +class _tlm_adjoint__log1p(sp.Function): # noqa: N801 + def fdiff(self, argindex=1): + if argindex == 1: + return sp.Integer(1) / (sp.Integer(1) + self.args[0]) + + +@register_function(np.hypot, "numpy.hypot") +class _tlm_adjoint__hypot(sp.Function): # noqa: N801 + def fdiff(self, argindex=1): + if argindex == 1: + return self.args[0] / _tlm_adjoint__hypot(self.args[0], + self.args[1]) + elif argindex == 2: + return self.args[1] / _tlm_adjoint__hypot(self.args[0], + self.args[1]) + + class _tlm_adjoint__OverloadedFloat(np.lib.mixins.NDArrayOperatorsMixin, # noqa: E501,N801 SymbolicFloat): """A subclass of :class:`.SymbolicFloat` with operator overloading. @@ -702,6 +720,11 @@ def arctan(x): def arctan2(x1, x2): return sp.atan2(x1, x2) + @staticmethod + @register_operation(np.hypot) + def hypot(x1, x2): + return _tlm_adjoint__hypot(x1, x2) + @staticmethod @register_operation(np.sinh) def sinh(x): @@ -737,6 +760,11 @@ def arctanh(x): def exp(x): return sp.exp(x) + @staticmethod + @register_operation(np.exp2) + def exp2(x): + return 2 ** x + @staticmethod @register_operation(np.expm1) def expm1(x): @@ -747,16 +775,41 @@ def expm1(x): def log(x): return sp.log(x) + @staticmethod + @register_operation(np.log2) + def log2(x): + return sp.log(x, 2) + @staticmethod @register_operation(np.log10) def log10(x): return sp.log(x, 10) + @staticmethod + @register_operation(np.log1p) + def log1p(x): + return _tlm_adjoint__log1p(x) + @staticmethod @register_operation(np.sqrt) def sqrt(x): return sp.sqrt(x) + @staticmethod + @register_operation(np.square) + def square(x): + return x ** 2 + + @staticmethod + @register_operation(np.cbrt) + def cbrt(x): + return x ** sp.Rational(1, 3) + + @staticmethod + @register_operation(np.reciprocal) + def reciprocal(x): + return sp.Integer(1) / x + # Required by Sphinx class OverloadedFloat(_tlm_adjoint__OverloadedFloat):