Skip to content

Commit

Permalink
Update jax.scipy.special.gamma and gammasgn to return NaN for negativ…
Browse files Browse the repository at this point in the history
…e integer inputs.

Change to match upstream scipy: scipy/scipy#21827.

Fixes #24875
  • Loading branch information
hawkinsp committed Nov 18, 2024
1 parent f7ae0f9 commit a4552bb
Show file tree
Hide file tree
Showing 10 changed files with 62 additions and 13 deletions.
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,9 @@ When releasing, please add the new-release-boilerplate to docs/pallas/CHANGELOG.
* {func}`jax.scipy.linalg.toeplitz` now does implicit batching on multi-dimensional
inputs. To recover the previous behavior, you can call {func}`jax.numpy.ravel`
on the function inputs.
* {func}`jax.scipy.special.gamma` and {func}`jax.scipy.special.gammasgn` now
return NaN for negative integer inputs, to match the behavior of SciPy from
https://github.com/scipy/scipy/pull/21827.
* `jax.clear_backends` was removed after being deprecated in v0.4.26.

* New Features
Expand Down
3 changes: 2 additions & 1 deletion build/requirements_lock_3_10.txt
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ markdown-it-py==3.0.0 \
--hash=sha256:355216845c60bd96232cd8d8c40e8f9765cc86f46880e43a8fd22dc1a1a8cab1 \
--hash=sha256:e3f60a94fa066dc52ec76661e37c851cb232d92f9886b15cb560aaada2df8feb
# via rich
matplotlib==3.8.4 ; python_version <= "3.10" \
matplotlib==3.8.4 ; python_version == "3.10" \
--hash=sha256:1c13f041a7178f9780fb61cc3a2b10423d5e125480e4be51beaf62b172413b67 \
--hash=sha256:232ce322bfd020a434caaffbd9a95333f7c2491e59cfc014041d95e38ab90d1c \
--hash=sha256:493e9f6aa5819156b58fce42b296ea31969f2aab71c5b680b4ea7a3cb5c07d94 \
Expand Down Expand Up @@ -390,6 +390,7 @@ packaging==24.0 \
--hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \
--hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9
# via
# -r build/test-requirements.txt
# auditwheel
# build
# matplotlib
Expand Down
1 change: 1 addition & 0 deletions build/requirements_lock_3_11.txt
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ packaging==24.0 \
--hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \
--hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9
# via
# -r build/test-requirements.txt
# auditwheel
# build
# matplotlib
Expand Down
1 change: 1 addition & 0 deletions build/requirements_lock_3_12.txt
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,7 @@ packaging==24.0 \
--hash=sha256:2ddfb553fdf02fb784c234c7ba6ccc288296ceabec964ad2eae3777778130bc5 \
--hash=sha256:eb82c5e3e56209074766e6885bb04b8c38a0c015d0a30036ebe7ece34c9989e9
# via
# -r build/test-requirements.txt
# auditwheel
# build
# matplotlib
Expand Down
1 change: 1 addition & 0 deletions build/requirements_lock_3_13.txt
Original file line number Diff line number Diff line change
Expand Up @@ -444,6 +444,7 @@ packaging==24.1 \
--hash=sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002 \
--hash=sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124
# via
# -r build/test-requirements.txt
# auditwheel
# build
# matplotlib
Expand Down
1 change: 1 addition & 0 deletions build/test-requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -19,3 +19,4 @@ matplotlib~=3.8.4; python_version=="3.10"
matplotlib; python_version>="3.11"
opt-einsum
auditwheel
packaging
26 changes: 24 additions & 2 deletions jax/_src/scipy/special.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@ def gammaln(x: ArrayLike) -> Array:
return lax.lgamma(x)


@jit
def gammasgn(x: ArrayLike) -> Array:
r"""Sign of the gamma function.
Expand All @@ -81,6 +82,13 @@ def gammasgn(x: ArrayLike) -> Array:
Where :math:`\Gamma` is the :func:`~jax.scipy.special.gamma` function.
Because :math:`\Gamma(x)` is never zero, no condition is required for this case.
* if :math:`x = -\infty`, NaN is returned.
* if :math:`x = \pm 0`, :math:`\pm 1` is returned.
* if :math:`x` is a negative integer, NaN is returned. The sign of gamma
at a negative integer depends on from which side the pole is approached.
* if :math:`x = \infty`, :math:`1` is returned.
* if :math:`x` is NaN, NaN is returned.
Args:
x: arraylike, real valued.
Expand All @@ -92,8 +100,14 @@ def gammasgn(x: ArrayLike) -> Array:
- :func:`jax.scipy.special.gammaln`: the natural log of the gamma function
"""
x, = promote_args_inexact("gammasgn", x)
typ = x.dtype.type
floor_x = lax.floor(x)
return jnp.where((x > 0) | (x == floor_x) | (floor_x % 2 == 0), 1.0, -1.0)
x_negative = x < 0
return jnp.select(
[(x_negative & (x == floor_x)) | jnp.isnan(x),
(x_negative & (floor_x % 2 != 0)) | ((x == 0) & jnp.signbit(x))],
[typ(np.nan), typ(-1.0)],
typ(1.0))


def gamma(x: ArrayLike) -> Array:
Expand All @@ -115,6 +129,13 @@ def gamma(x: ArrayLike) -> Array:
\Gamma(n) = (n - 1)!
* if :math:`z = -\infty`, NaN is returned.
* if :math:`x = \pm 0`, :math:`\pm \infty` is returned.
* if :math:`x` is a negative integer, NaN is returned. The sign of gamma
at a negative integer depends on from which side the pole is approached.
* if :math:`x = \infty`, :math:`\infty` is returned.
* if :math:`x` is NaN, NaN is returned.
Args:
x: arraylike, real valued.
Expand All @@ -127,7 +148,8 @@ def gamma(x: ArrayLike) -> Array:
- :func:`jax.scipy.special.gammasgn`: the sign of the gamma function
Notes:
Unlike the scipy version, JAX's ``gamma`` does not support complex-valued inputs.
Unlike the scipy version, JAX's ``gamma`` does not support complex-valued
inputs.
"""
x, = promote_args_inexact("gamma", x)
return gammasgn(x) * lax.exp(lax.lgamma(x))
Expand Down
1 change: 1 addition & 0 deletions jaxlib/jax.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ _py_deps = {
"matplotlib": ["@pypi_matplotlib//:pkg"],
"mpmath": [],
"opt_einsum": ["@pypi_opt_einsum//:pkg"],
"packaging": ["@pypi_packaging//:pkg"],
"pil": ["@pypi_pillow//:pkg"],
"portpicker": ["@pypi_portpicker//:pkg"],
"ml_dtypes": ["@pypi_ml_dtypes//:pkg"],
Expand Down
2 changes: 1 addition & 1 deletion tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -546,7 +546,7 @@ jax_multiplatform_test(
"gpu": 20,
"tpu": 20,
},
deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing"),
deps = py_deps("numpy") + py_deps("scipy") + py_deps("absl/testing") + py_deps("packaging"),
)

jax_multiplatform_test(
Expand Down
36 changes: 27 additions & 9 deletions tests/lax_scipy_special_functions_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,14 @@

from absl.testing import absltest
from absl.testing import parameterized
from packaging.version import Version

import numpy as np
import scipy
import scipy.special as osp_special

import jax
import jax.numpy as jnp
from jax._src import test_util as jtu
from jax.scipy import special as lsp_special

Expand Down Expand Up @@ -214,32 +217,47 @@ def partial_lax_op(*vals):
n=[0, 1, 2, 3, 10, 50]
)
def testScipySpecialFunBernoulli(self, n):
dtype = jax.numpy.zeros(0).dtype # default float dtype.
dtype = jnp.zeros(0).dtype # default float dtype.
scipy_op = lambda: osp_special.bernoulli(n).astype(dtype)
lax_op = functools.partial(lsp_special.bernoulli, n)
args_maker = lambda: []
self._CheckAgainstNumpy(scipy_op, lax_op, args_maker, atol=0, rtol=1E-5)
self._CompileAndCheck(lax_op, args_maker, atol=0, rtol=1E-5)

def testGammaSign(self):
# Test that the sign of `gamma` matches at integer-valued inputs.
dtype = jax.numpy.zeros(0).dtype # default float dtype.
args_maker = lambda: [np.arange(-10, 10).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.gamma, lsp_special.gamma, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.gamma, args_maker, rtol=rtol)
dtype = jnp.zeros(0).dtype # default float dtype.
typ = dtype.type
testcases = [
(np.arange(-10, 0).astype(dtype), np.array([np.nan] * 10, dtype=dtype)),
(np.nextafter(np.arange(-5, 0).astype(dtype), typ(-np.inf)),
np.array([1, -1, 1, -1, 1], dtype=dtype)),
(np.nextafter(np.arange(-5, 0).astype(dtype), typ(np.inf)),
np.array([-1, 1, -1, 1, -1], dtype=dtype)),
(np.arange(0, 10).astype(dtype), np.ones((10,), dtype)),
(np.nextafter(np.arange(0, 10).astype(dtype), typ(np.inf)),
np.ones((10,), dtype)),
(np.nextafter(np.arange(1, 10).astype(dtype), typ(-np.inf)),
np.ones((9,), dtype)),
(np.array([-np.inf, -0.0, 0.0, np.inf, np.nan]),
np.array([np.nan, -1.0, 1.0, 1.0, np.nan]))
]
for inp, out in testcases:
self.assertArraysEqual(out, lsp_special.gammasgn(inp))
self.assertArraysEqual(out, jnp.sign(lsp_special.gamma(inp)))
if Version(scipy.__version__) >= Version("1.15"):
self.assertArraysEqual(out, osp_special.gammasgn(inp))

def testNdtriExtremeValues(self):
# Testing at the extreme values (bounds (0. and 1.) and outside the bounds).
dtype = jax.numpy.zeros(0).dtype # default float dtype.
dtype = jnp.zeros(0).dtype # default float dtype.
args_maker = lambda: [np.arange(-10, 10).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
self._CheckAgainstNumpy(osp_special.ndtri, lsp_special.ndtri, args_maker, rtol=rtol)
self._CompileAndCheck(lsp_special.ndtri, args_maker, rtol=rtol)

def testRelEntrExtremeValues(self):
# Testing at the extreme values (bounds (0. and 1.) and outside the bounds).
dtype = jax.numpy.zeros(0).dtype # default float dtype.
dtype = jnp.zeros(0).dtype # default float dtype.
args_maker = lambda: [np.array([-2, -2, -2, -1, -1, -1, 0, 0, 0]).astype(dtype),
np.array([-1, 0, 1, -1, 0, 1, -1, 0, 1]).astype(dtype)]
rtol = 1E-3 if jtu.test_device_matches(["tpu"]) else 1e-5
Expand Down

0 comments on commit a4552bb

Please sign in to comment.