Skip to content

Commit

Permalink
adding test_where for unary operations
Browse files Browse the repository at this point in the history
  • Loading branch information
ipdemes committed Oct 18, 2023
1 parent 73e0609 commit dfafdcc
Show file tree
Hide file tree
Showing 5 changed files with 50 additions and 12 deletions.
2 changes: 2 additions & 0 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -3159,6 +3159,8 @@ def unary_reduction(
)

is_where = bool(where is not None)
if is_where:
where = self.runtime.to_deferred_array(where)
# See if we are doing reduction to a point or another region
if lhs_array.size == 1:
assert axes is None or lhs_array.ndim == rhs_array.ndim - (
Expand Down
15 changes: 10 additions & 5 deletions tests/integration/test_logical.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,19 +102,24 @@ def test_nd_inputs(ndim, func):
assert np.array_equal(out_np, out_num)


@pytest.mark.skip
def test_where():
# "the `where` parameter is currently not supported"
x = np.array([[True, True, False], [True, True, True]])
y = np.array([[True, False], [True, True]])
cy = num.array(y)

# where needs to be broadcasted
assert num.array_equal(
num.all(cy, where=[True, False]), np.all(x, where=[True, False])
num.all(cy, where=[True, False]), np.all(y, where=[True, False])
)
assert num.array_equal(
num.any(cy, where=[[True], [False]]),
np.any(x, where=[[True], [False]]),
np.any(y, where=[[True], [False]]),
)

# Where is a boolean
assert num.array_equal(num.all(cy, where=True), np.all(y, where=True))
assert num.array_equal(
num.any(cy, where=False),
np.any(y, where=False),
)


Expand Down
19 changes: 19 additions & 0 deletions tests/integration/test_nan_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,25 @@ def test_all_nans_nansum(self, ndim):

assert out_num == 0.0

def test_where(self):
arr = [[1, np.nan, 3], [2, np.nan, 4]]
out_np = np.nansum(arr, where=[False, True, True])
out_num = num.nansum(arr, where=[False, True, True])
assert np.allclose(out_np, out_num)

out_np = np.nanprod(arr, where=[False, True, True])
out_num = num.nanprod(arr, where=[False, True, True])
assert np.allclose(out_np, out_num)

# where is a boolean
out_np = np.nanmax(arr, where=True)
out_num = num.nanmax(arr, where=True)
assert np.allclose(out_np, out_num)

out_np = np.nanmin(arr, where=True)
out_num = num.nanmin(arr, where=True)
assert np.allclose(out_np, out_num)


class TestCornerCases:
"""
Expand Down
14 changes: 10 additions & 4 deletions tests/integration/test_prod.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,15 +147,21 @@ def test_initial_empty_array(self):
out_np = np.prod(arr_np, initial=initial_value)
assert allclose(out_np, out_num)

@pytest.mark.xfail
def test_where(self):
arr = [[1, 2], [3, 4]]
out_np = np.prod(arr, where=[False, True]) # return 8
# cuNumeric raises NotImplementedError:
# the `where` parameter is currently not supported
out_np = np.prod(arr, where=[False, True])
out_num = num.prod(arr, where=[False, True])
assert allclose(out_np, out_num)

# where is boolean
out_np = np.prod(arr, where=True)
out_num = num.prod(arr, where=True)
assert allclose(out_np, out_num)

out_np = np.prod(arr, where=False)
out_num = num.prod(arr, where=False)
assert allclose(out_np, out_num)


class TestProdPositive(object):
"""
Expand Down
12 changes: 9 additions & 3 deletions tests/integration/test_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,15 +151,21 @@ def test_initial_empty_array(self):
out_np = np.sum(arr_np, initial=initial_value) # return initial_value
assert allclose(out_np, out_num)

@pytest.mark.xfail
def test_where(self):
arr = [[1, 2], [3, 4]]
out_np = np.sum(arr, where=[False, True]) # return 6
# cuNumeric raises NotImplementedError:
# "the `where` parameter is currently not supported"
out_num = num.sum(arr, where=[False, True])
assert allclose(out_np, out_num)

# where is a boolean
out_np = np.sum(arr, where=True)
out_num = num.sum(arr, where=True)
assert allclose(out_np, out_num)

out_np = np.sum(arr, where=False)
out_num = num.sum(arr, where=False)
assert allclose(out_np, out_num)


class TestSumPositive(object):
"""
Expand Down

0 comments on commit dfafdcc

Please sign in to comment.