Skip to content

Commit

Permalink
Polish math-tools overloads
Browse files Browse the repository at this point in the history
  • Loading branch information
NeilGirdhar committed Sep 14, 2024
1 parent 87dec3a commit 09ad5b3
Showing 1 changed file with 40 additions and 62 deletions.
102 changes: 40 additions & 62 deletions tjax/_src/math_tools.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,20 @@
from __future__ import annotations

from typing import TypeVar, cast, overload
from typing import TypeVar, overload

import jax
import numpy as np
from array_api_compat import get_namespace

from .annotations import (BooleanArray, ComplexArray, IntegralArray, JaxArray, JaxComplexArray,
JaxRealArray, RealArray)
from .annotations import (BooleanArray, ComplexArray, IntegralArray, JaxBooleanArray,
JaxComplexArray, JaxIntegralArray, JaxRealArray, NumpyBooleanArray,
NumpyComplexArray, NumpyIntegralArray, NumpyRealArray, RealArray)


@overload
def abs_square(x: JaxComplexArray) -> JaxRealArray:
...


def abs_square(x: JaxComplexArray) -> JaxRealArray: ...
@overload
def abs_square(x: ComplexArray) -> RealArray:
...


def abs_square(x: ComplexArray) -> RealArray: ...
def abs_square(x: ComplexArray) -> RealArray:
xp = get_namespace(x)
# TODO: remove workaround when Jax is 0.4.27.
Expand All @@ -27,15 +23,9 @@ def abs_square(x: ComplexArray) -> RealArray:

# TODO: Remove this when the Array API has it with broadcasting under xp.linalg.norm.
@overload
def outer_product(x: JaxRealArray, y: JaxRealArray) -> JaxRealArray:
...


def outer_product(x: JaxRealArray, y: JaxRealArray) -> JaxRealArray: ...
@overload
def outer_product(x: RealArray, y: RealArray) -> RealArray:
...


def outer_product(x: RealArray, y: RealArray) -> RealArray: ...
def outer_product(x: RealArray, y: RealArray) -> RealArray:
"""Return the broadcasted outer product of a vector with itself.
Expand All @@ -48,15 +38,9 @@ def outer_product(x: RealArray, y: RealArray) -> RealArray:


@overload
def matrix_vector_mul(x: JaxRealArray, y: JaxRealArray) -> JaxRealArray:
...


def matrix_vector_mul(x: JaxRealArray, y: JaxRealArray) -> JaxRealArray: ...
@overload
def matrix_vector_mul(x: RealArray, y: RealArray) -> RealArray:
...


def matrix_vector_mul(x: RealArray, y: RealArray) -> RealArray: ...
def matrix_vector_mul(x: RealArray, y: RealArray) -> RealArray:
"""Return the matrix-vector product.
Expand All @@ -72,15 +56,9 @@ def matrix_vector_mul(x: RealArray, y: RealArray) -> RealArray:


@overload
def matrix_dot_product(x: JaxRealArray, y: JaxRealArray) -> JaxRealArray:
...


def matrix_dot_product(x: JaxRealArray, y: JaxRealArray) -> JaxRealArray: ...
@overload
def matrix_dot_product(x: RealArray, y: RealArray) -> RealArray:
...


def matrix_dot_product(x: RealArray, y: RealArray) -> RealArray: ...
def matrix_dot_product(x: RealArray, y: RealArray) -> RealArray:
"""Return the "matrix dot product" of a matrix with the outer product of a vector.
Expand All @@ -93,6 +71,24 @@ def matrix_dot_product(x: RealArray, y: RealArray) -> RealArray:
return xp.sum(x * y, axis=(-2, -1))


@overload
def divide_where(dividend: JaxRealArray,
divisor: JaxRealArray | JaxIntegralArray,
*,
where: JaxBooleanArray | None = None,
otherwise: JaxRealArray | None = None) -> JaxRealArray: ...
@overload
def divide_where(dividend: NumpyRealArray,
divisor: NumpyRealArray | NumpyIntegralArray,
*,
where: NumpyBooleanArray | None = None,
otherwise: NumpyRealArray | None = None) -> NumpyRealArray: ...
@overload
def divide_where(dividend: NumpyComplexArray,
divisor: NumpyComplexArray | NumpyIntegralArray,
*,
where: NumpyBooleanArray | None = None,
otherwise: NumpyComplexArray | None = None) -> NumpyComplexArray: ...
def divide_where(dividend: ComplexArray,
divisor: ComplexArray | IntegralArray,
*,
Expand All @@ -116,52 +112,34 @@ def divide_where(dividend: ComplexArray,


@overload
def divide_nonnegative(dividend: JaxRealArray, divisor: JaxRealArray) -> JaxRealArray:
...


def divide_nonnegative(dividend: JaxRealArray, divisor: JaxRealArray) -> JaxRealArray: ...
@overload
def divide_nonnegative(dividend: RealArray, divisor: RealArray) -> RealArray:
...


def divide_nonnegative(dividend: NumpyRealArray, divisor: NumpyRealArray) -> NumpyRealArray: ...
def divide_nonnegative(dividend: RealArray, divisor: RealArray) -> RealArray:
"""Quotient for use with positive reals that never returns NaN.
Returns: The quotient assuming that the dividend and divisor are nonnegative, and infinite
whenever the divisor equals zero.
"""
xp = get_namespace(dividend, divisor)
return cast(RealArray, divide_where(dividend, divisor, where=divisor > 0.0,
otherwise=xp.asarray(xp.inf)))
return divide_where(dividend, divisor, where=divisor > 0.0, # pyright: ignore
otherwise=xp.asarray(xp.inf))


# Remove when https://github.com/scipy/scipy/pull/18605 is released.
@overload
def softplus(x: JaxRealArray) -> JaxRealArray:
...


def softplus(x: JaxRealArray) -> JaxRealArray: ...
@overload
def softplus(x: RealArray) -> RealArray:
...


def softplus(x: RealArray) -> RealArray: ...
def softplus(x: RealArray) -> RealArray:
xp = get_namespace(x)
return xp.logaddexp(xp.asarray(0.0), x)


@overload
def inverse_softplus(y: JaxRealArray) -> JaxRealArray:
...


def inverse_softplus(y: JaxRealArray) -> JaxRealArray: ...
@overload
def inverse_softplus(y: RealArray) -> RealArray:
...


def inverse_softplus(y: RealArray) -> RealArray: ...
def inverse_softplus(y: RealArray) -> RealArray:
xp = get_namespace(y)
return xp.where(y > 80.0, # noqa: PLR2004
Expand All @@ -187,7 +165,7 @@ def create_diagonal_array(m: T) -> T:
for index in np.ndindex(*pre):
target_index = (*index, slice(None, None, n + 1))
source_values = m[*index, :] # type: ignore[arg-type]
if isinstance(retval, JaxArray):
if isinstance(retval, jax.Array):
retval.at[target_index].set(source_values)
else:
retval[target_index] = source_values
Expand Down

0 comments on commit 09ad5b3

Please sign in to comment.