Skip to content

Commit

Permalink
Merge pull request #419 from tlm-adjoint/jrmaddison/float
Browse files Browse the repository at this point in the history
`SymbolicFloat` updates
  • Loading branch information
jrmaddison authored Oct 24, 2023
2 parents 52c96b6 + 550d081 commit 0b8fee1
Show file tree
Hide file tree
Showing 4 changed files with 93 additions and 12 deletions.
16 changes: 12 additions & 4 deletions tests/base/test_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,16 @@ def forward(y):
np.arccosh,
np.arctanh,
np.exp,
np.exp2,
np.expm1,
np.log,
np.log2,
np.log10,
np.sqrt])
np.log1p,
np.sqrt,
np.square,
np.cbrt,
np.reciprocal])
@seed_test
def test_float_unary_overloading(setup_test, # noqa: F811
op):
Expand All @@ -123,7 +129,7 @@ def forward(y):

return (c - 1.0) ** 4

if op is np.arccosh:
if op in {np.arccosh, np.reciprocal}:
y = Float(1.1)
else:
y = Float(0.1)
Expand Down Expand Up @@ -165,11 +171,13 @@ def forward(y):
operator.mul,
operator.truediv,
operator.pow,
np.arctan2])
np.arctan2,
np.hypot])
@seed_test
def test_float_binary_overloading(setup_test, # noqa: F811
dtype, op):
if op is np.arctan2 and issubclass(dtype, (complex, np.complexfloating)):
if op in {np.arctan2, np.hypot} \
and issubclass(dtype, (complex, np.complexfloating)):
pytest.skip()
set_default_float_dtype(dtype)

Expand Down
19 changes: 13 additions & 6 deletions tests/base/test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,16 @@ def forward(y):
np.arccosh,
np.arctanh,
np.exp,
np.exp2,
np.expm1,
np.log,
np.log2,
np.log10,
np.sqrt])
np.log1p,
np.sqrt,
np.square,
np.cbrt,
np.reciprocal])
@seed_test
def test_jax_unary_overloading(setup_test, jax_tlm_config, # noqa: F811
op):
Expand All @@ -119,7 +125,7 @@ def forward(y):

return (c - 1.0) ** 4

if op is np.arccosh:
if op in {np.arccosh, np.reciprocal}:
y = np.array([1.1, 1.2], dtype=np.double)
else:
y = np.array([0.1, 0.2], dtype=np.double)
Expand Down Expand Up @@ -160,16 +166,17 @@ def forward(y):
operator.mul,
operator.truediv,
operator.pow,
np.arctan2])
np.arctan2,
np.hypot])
@seed_test
def test_jax_binary_overloading(setup_test, jax_tlm_config, # noqa: F811
dtype, op):
if op in {np.arctan2, np.hypot} \
and issubclass(dtype, (complex, np.complexfloating)):
pytest.skip()
set_default_float_dtype(dtype)
set_default_jax_dtype(dtype)

if op is np.arctan2 and issubclass(dtype, (complex, np.complexfloating)):
pytest.skip()

def forward(y):
x = y * y
x = op(x, y)
Expand Down
4 changes: 4 additions & 0 deletions tlm_adjoint/jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,10 @@ def __init__(self, n, *, dtype=None, comm=None):
if comm is None:
comm = DEFAULT_COMM

if not issubclass(dtype, (float, np.floating,
complex, np.complexfloating)):
raise TypeError("Invalid dtype")

self._n = n
self._dtype = dtype
self._comm = comm_dup_cached(comm)
Expand Down
66 changes: 64 additions & 2 deletions tlm_adjoint/overloaded_float.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ def __init__(self, float_cls=None, *, dtype=None, comm=None):
if comm is None:
comm = DEFAULT_COMM

if not issubclass(dtype, (float, np.floating,
complex, np.complexfloating)):
raise TypeError("Invalid dtype")

self._comm = comm_dup_cached(comm)
self._dtype = dtype
self._float_cls = float_cls
Expand Down Expand Up @@ -376,8 +380,13 @@ def __init__(self, value=0.0, *, name=None, space_type="primal",
else:
self.assign(value, annotate=annotate, tlm=tlm)

def __new__(cls, *args, **kwargs):
return super().__new__(cls, new_symbol_name())
def __new__(cls, *args, dtype=None, **kwargs):
if dtype is None:
dtype = _default_dtype
if issubclass(dtype, (float, np.floating)):
return super().__new__(cls, new_symbol_name(), real=True)
else:
return super().__new__(cls, new_symbol_name(), complex=True)

def new(self, value=0.0, *,
name=None,
Expand Down Expand Up @@ -575,6 +584,24 @@ def fdiff(self, argindex=1):
return sp.exp(self.args[0])


@register_function(np.log1p, "numpy.log1p")
class _tlm_adjoint__log1p(sp.Function): # noqa: N801
def fdiff(self, argindex=1):
if argindex == 1:
return sp.Integer(1) / (sp.Integer(1) + self.args[0])


@register_function(np.hypot, "numpy.hypot")
class _tlm_adjoint__hypot(sp.Function): # noqa: N801
def fdiff(self, argindex=1):
if argindex == 1:
return self.args[0] / _tlm_adjoint__hypot(self.args[0],
self.args[1])
elif argindex == 2:
return self.args[1] / _tlm_adjoint__hypot(self.args[0],
self.args[1])


class _tlm_adjoint__OverloadedFloat(np.lib.mixins.NDArrayOperatorsMixin, # noqa: E501,N801
SymbolicFloat):
"""A subclass of :class:`.SymbolicFloat` with operator overloading.
Expand Down Expand Up @@ -697,6 +724,11 @@ def arctan(x):
def arctan2(x1, x2):
return sp.atan2(x1, x2)

@staticmethod
@register_operation(np.hypot)
def hypot(x1, x2):
return _tlm_adjoint__hypot(x1, x2)

@staticmethod
@register_operation(np.sinh)
def sinh(x):
Expand Down Expand Up @@ -732,6 +764,11 @@ def arctanh(x):
def exp(x):
return sp.exp(x)

@staticmethod
@register_operation(np.exp2)
def exp2(x):
return 2 ** x

@staticmethod
@register_operation(np.expm1)
def expm1(x):
Expand All @@ -742,16 +779,41 @@ def expm1(x):
def log(x):
return sp.log(x)

@staticmethod
@register_operation(np.log2)
def log2(x):
return sp.log(x, 2)

@staticmethod
@register_operation(np.log10)
def log10(x):
return sp.log(x, 10)

@staticmethod
@register_operation(np.log1p)
def log1p(x):
return _tlm_adjoint__log1p(x)

@staticmethod
@register_operation(np.sqrt)
def sqrt(x):
return sp.sqrt(x)

@staticmethod
@register_operation(np.square)
def square(x):
return x ** 2

@staticmethod
@register_operation(np.cbrt)
def cbrt(x):
return x ** sp.Rational(1, 3)

@staticmethod
@register_operation(np.reciprocal)
def reciprocal(x):
return sp.Integer(1) / x


# Required by Sphinx
class OverloadedFloat(_tlm_adjoint__OverloadedFloat):
Expand Down

0 comments on commit 0b8fee1

Please sign in to comment.