Skip to content

Commit

Permalink
Further Float overloading, further JAX overloading tests
Browse files Browse the repository at this point in the history
  • Loading branch information
jrmaddison committed Oct 24, 2023
1 parent 128c7d4 commit 2d44fcb
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 10 deletions.
16 changes: 12 additions & 4 deletions tests/base/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,16 @@ def forward(y):
np.arccosh,
np.arctanh,
np.exp,
np.exp2,
np.expm1,
np.log,
np.log2,
np.log10,
np.sqrt])
np.log1p,
np.sqrt,
np.square,
np.cbrt,
np.reciprocal])
@seed_test
def test_float_unary_overloading(setup_test, # noqa: F811
op):
Expand All @@ -123,7 +129,7 @@ def forward(y):

return (c - 1.0) ** 4

if op is np.arccosh:
if op in {np.arccosh, np.reciprocal}:
y = Float(1.1)
else:
y = Float(0.1)
Expand Down Expand Up @@ -165,11 +171,13 @@ def forward(y):
operator.mul,
operator.truediv,
operator.pow,
np.arctan2])
np.arctan2,
np.hypot])
@seed_test
def test_float_binary_overloading(setup_test, # noqa: F811
dtype, op):
if op is np.arctan2 and issubclass(dtype, (complex, np.complexfloating)):
if op in {np.arctan2, np.hypot} \
and issubclass(dtype, (complex, np.complexfloating)):
pytest.skip()
set_default_float_dtype(dtype)

Expand Down
19 changes: 13 additions & 6 deletions tests/base/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,16 @@ def forward(y):
np.arccosh,
np.arctanh,
np.exp,
np.exp2,
np.expm1,
np.log,
np.log2,
np.log10,
np.sqrt])
np.log1p,
np.sqrt,
np.square,
np.cbrt,
np.reciprocal])
@seed_test
def test_jax_unary_overloading(setup_test, jax_tlm_config, # noqa: F811
op):
Expand All @@ -119,7 +125,7 @@ def forward(y):

return (c - 1.0) ** 4

if op is np.arccosh:
if op in {np.arccosh, np.reciprocal}:
y = np.array([1.1, 1.2], dtype=np.double)
else:
y = np.array([0.1, 0.2], dtype=np.double)
Expand Down Expand Up @@ -160,16 +166,17 @@ def forward(y):
operator.mul,
operator.truediv,
operator.pow,
np.arctan2])
np.arctan2,
np.hypot])
@seed_test
def test_jax_binary_overloading(setup_test, jax_tlm_config, # noqa: F811
dtype, op):
if op in {np.arctan2, np.hypot} \
and issubclass(dtype, (complex, np.complexfloating)):
pytest.skip()
set_default_float_dtype(dtype)
set_default_jax_dtype(dtype)

if op is np.arctan2 and issubclass(dtype, (complex, np.complexfloating)):
pytest.skip()

def forward(y):
x = y * y
x = op(x, y)
Expand Down
53 changes: 53 additions & 0 deletions tlm_adjoint/overloaded_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -580,6 +580,24 @@ def fdiff(self, argindex=1):
return sp.exp(self.args[0])


@register_function(np.log1p, "numpy.log1p")
class _tlm_adjoint__log1p(sp.Function): # noqa: N801
def fdiff(self, argindex=1):
if argindex == 1:
return sp.Integer(1) / (sp.Integer(1) + self.args[0])


@register_function(np.hypot, "numpy.hypot")
class _tlm_adjoint__hypot(sp.Function): # noqa: N801
def fdiff(self, argindex=1):
if argindex == 1:
return self.args[0] / _tlm_adjoint__hypot(self.args[0],
self.args[1])
elif argindex == 2:
return self.args[1] / _tlm_adjoint__hypot(self.args[0],
self.args[1])


class _tlm_adjoint__OverloadedFloat(np.lib.mixins.NDArrayOperatorsMixin, # noqa: E501,N801
SymbolicFloat):
"""A subclass of :class:`.SymbolicFloat` with operator overloading.
Expand Down Expand Up @@ -702,6 +720,11 @@ def arctan(x):
def arctan2(x1, x2):
return sp.atan2(x1, x2)

@staticmethod
@register_operation(np.hypot)
def hypot(x1, x2):
return _tlm_adjoint__hypot(x1, x2)

@staticmethod
@register_operation(np.sinh)
def sinh(x):
Expand Down Expand Up @@ -737,6 +760,11 @@ def arctanh(x):
def exp(x):
return sp.exp(x)

@staticmethod
@register_operation(np.exp2)
def exp2(x):
return 2 ** x

@staticmethod
@register_operation(np.expm1)
def expm1(x):
Expand All @@ -747,16 +775,41 @@ def expm1(x):
def log(x):
return sp.log(x)

@staticmethod
@register_operation(np.log2)
def log2(x):
return sp.log(x, 2)

@staticmethod
@register_operation(np.log10)
def log10(x):
return sp.log(x, 10)

@staticmethod
@register_operation(np.log1p)
def log1p(x):
return _tlm_adjoint__log1p(x)

@staticmethod
@register_operation(np.sqrt)
def sqrt(x):
return sp.sqrt(x)

@staticmethod
@register_operation(np.square)
def square(x):
return x ** 2

@staticmethod
@register_operation(np.cbrt)
def cbrt(x):
return x ** sp.Rational(1, 3)

@staticmethod
@register_operation(np.reciprocal)
def reciprocal(x):
return sp.Integer(1) / x


# Required by Sphinx
class OverloadedFloat(_tlm_adjoint__OverloadedFloat):
Expand Down

0 comments on commit 2d44fcb

Please sign in to comment.