Skip to content

Commit

Permalink
cunumeric: addressing review comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohan Yadav committed Feb 13, 2024
1 parent 01ef150 commit a99f44e
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 23 deletions.
24 changes: 9 additions & 15 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -6242,13 +6242,13 @@ def nansum(
# Arithmetic operations


@add_boilerplate("a")
@add_boilerplate("a", "prepend", "append")
def diff(
a: ndarray,
n: int = 1,
axis: int = -1,
prepend: Any = None,
append: Any = None,
prepend: ndarray | None = None,
append: ndarray | None = None,
) -> ndarray:
"""
Calculate the n-th discrete difference along the given axis.
Expand Down Expand Up @@ -6278,11 +6278,10 @@ def diff(
except along `axis` where the dimension is smaller by `n`. The
type of the output is the same as the type of the difference
between any two elements of `a`. This is the same as the type of
`a` in most cases. A notable exception is `datetime64`, which
results in a `timedelta64` output array.
`a` in most cases.
See Also
--------
gradient, ediff1d, cumsum
numpy.diff
Notes
-----
Type is preserved for boolean arrays, so the result will contain
Expand Down Expand Up @@ -6314,9 +6313,6 @@ def diff(
[5, 1, 2]])
>>> np.diff(x, axis=0)
array([[-1, 2, 0, -2]])
>>> x = np.arange('1066-10-13', '1066-10-16', dtype=np.datetime64)
>>> np.diff(x)
array([1, 1], dtype='timedelta64[D]')
Availability
--------
Expand All @@ -6336,33 +6332,31 @@ def diff(

combined = []
if prepend is not None:
prepend = np.asanyarray(prepend)
if prepend.ndim == 0:
shape = list(a.shape)
shape[axis] = 1
prepend = np.broadcast_to(prepend, tuple(shape))
prepend = broadcast_to(prepend, tuple(shape))
combined.append(prepend)

combined.append(a)

if append is not None:
append = np.asanyarray(append)
if append.ndim == 0:
shape = list(a.shape)
shape[axis] = 1
append = np.broadcast_to(append, tuple(shape))
append = broadcast_to(append, tuple(shape))
combined.append(append)

if len(combined) > 1:
a = np.concatenate(combined, axis)
a = concatenate(combined, axis)

# Diffing with n > shape results in an empty array. We have
# to handle this case explicitly as our slicing routines raise
# an exception with out-of-bounds slices, while NumPy's dont.
if a.shape[axis] <= n:
shape = list(a.shape)
shape[axis] = 0
return np.empty(shape=shape, dtype=a.dtype)
return empty(shape=shape, dtype=a.dtype)

slice1l = [slice(None)] * nd
slice2l = [slice(None)] * nd
Expand Down
16 changes: 8 additions & 8 deletions tests/integration/test_diff.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
import pytest
from utils.comparisons import allclose

import cunumeric as cn
import cunumeric as num


@pytest.mark.parametrize(
Expand All @@ -41,22 +41,22 @@
)
def test_diff(args):
shape, n, axis, prepend, append = args
num = np.random.random(shape)
cun = cn.array(num)
nparr = np.random.random(shape)
cnarr = num.array(nparr)

# We are not adopting the np._NoValue default arguments
# for this function, as no special behavior is needed on None.
n_prepend = np._NoValue if prepend is None else prepend
n_append = np._NoValue if append is None else append
res_num = np.diff(num, n=n, axis=axis, prepend=n_prepend, append=n_append)
res_cn = cn.diff(cun, n=n, axis=axis, prepend=prepend, append=append)
res_np = np.diff(nparr, n=n, axis=axis, prepend=n_prepend, append=n_append)
res_cn = num.diff(cnarr, n=n, axis=axis, prepend=prepend, append=append)

assert allclose(res_num, res_cn)
assert allclose(res_np, res_cn)


def test_diff_nzero():
a = cn.ones(100)
ad = cn.diff(a, n=0)
a = num.ones(100)
ad = num.diff(a, n=0)
assert a is ad


Expand Down

0 comments on commit a99f44e

Please sign in to comment.