From dfafdcc6e01ada995621b12255fb650ff585f3a3 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Wed, 18 Oct 2023 10:05:30 -0700 Subject: [PATCH] adding test_where for unary operations --- cunumeric/deferred.py | 2 ++ tests/integration/test_logical.py | 15 ++++++++++----- tests/integration/test_nan_reduction.py | 19 +++++++++++++++++++ tests/integration/test_prod.py | 14 ++++++++++---- tests/integration/test_reduction.py | 12 +++++++++--- 5 files changed, 50 insertions(+), 12 deletions(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 02dcbb13f..c1a102e73 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -3159,6 +3159,8 @@ def unary_reduction( ) is_where = bool(where is not None) + if is_where: + where = self.runtime.to_deferred_array(where) # See if we are doing reduction to a point or another region if lhs_array.size == 1: assert axes is None or lhs_array.ndim == rhs_array.ndim - ( diff --git a/tests/integration/test_logical.py b/tests/integration/test_logical.py index ca9b99220..b0f83aaa6 100644 --- a/tests/integration/test_logical.py +++ b/tests/integration/test_logical.py @@ -102,19 +102,24 @@ def test_nd_inputs(ndim, func): assert np.array_equal(out_np, out_num) -@pytest.mark.skip def test_where(): - # "the `where` parameter is currently not supported" - x = np.array([[True, True, False], [True, True, True]]) y = np.array([[True, False], [True, True]]) cy = num.array(y) + # where needs to be broadcasted assert num.array_equal( - num.all(cy, where=[True, False]), np.all(x, where=[True, False]) + num.all(cy, where=[True, False]), np.all(y, where=[True, False]) ) assert num.array_equal( num.any(cy, where=[[True], [False]]), - np.any(x, where=[[True], [False]]), + np.any(y, where=[[True], [False]]), + ) + + # Where is a boolean + assert num.array_equal(num.all(cy, where=True), np.all(y, where=True)) + assert num.array_equal( + num.any(cy, where=False), + np.any(y, where=False), ) diff --git a/tests/integration/test_nan_reduction.py b/tests/integration/test_nan_reduction.py index 34bbd1447..d3b5fd7e9 100644 --- a/tests/integration/test_nan_reduction.py +++ b/tests/integration/test_nan_reduction.py @@ -291,6 +291,25 @@ def test_all_nans_nansum(self, ndim): assert out_num == 0.0 + def test_where(self): + arr = [[1, np.nan, 3], [2, np.nan, 4]] + out_np = np.nansum(arr, where=[False, True, True]) + out_num = num.nansum(arr, where=[False, True, True]) + assert np.allclose(out_np, out_num) + + out_np = np.nanprod(arr, where=[False, True, True]) + out_num = num.nanprod(arr, where=[False, True, True]) + assert np.allclose(out_np, out_num) + + # where is a boolean + out_np = np.nanmax(arr, where=True) + out_num = num.nanmax(arr, where=True) + assert np.allclose(out_np, out_num) + + out_np = np.nanmin(arr, where=True) + out_num = num.nanmin(arr, where=True) + assert np.allclose(out_np, out_num) + class TestCornerCases: """ diff --git a/tests/integration/test_prod.py b/tests/integration/test_prod.py index ab0f4def8..c004c95a3 100644 --- a/tests/integration/test_prod.py +++ b/tests/integration/test_prod.py @@ -147,15 +147,21 @@ def test_initial_empty_array(self): out_np = np.prod(arr_np, initial=initial_value) assert allclose(out_np, out_num) - @pytest.mark.xfail def test_where(self): arr = [[1, 2], [3, 4]] - out_np = np.prod(arr, where=[False, True]) # return 8 - # cuNumeric raises NotImplementedError: - # the `where` parameter is currently not supported + out_np = np.prod(arr, where=[False, True]) out_num = num.prod(arr, where=[False, True]) assert allclose(out_np, out_num) + # where is boolean + out_np = np.prod(arr, where=True) + out_num = num.prod(arr, where=True) + assert allclose(out_np, out_num) + + out_np = np.prod(arr, where=False) + out_num = num.prod(arr, where=False) + assert allclose(out_np, out_num) + class TestProdPositive(object): """ diff --git a/tests/integration/test_reduction.py b/tests/integration/test_reduction.py index a7a89a6af..f3379265b 100644 --- a/tests/integration/test_reduction.py +++ b/tests/integration/test_reduction.py @@ -151,15 +151,21 @@ def test_initial_empty_array(self): out_np = np.sum(arr_np, initial=initial_value) # return initial_value assert allclose(out_np, out_num) - @pytest.mark.xfail def test_where(self): arr = [[1, 2], [3, 4]] out_np = np.sum(arr, where=[False, True]) # return 6 - # cuNumeric raises NotImplementedError: - # "the `where` parameter is currently not supported" out_num = num.sum(arr, where=[False, True]) assert allclose(out_np, out_num) + # where is a boolean + out_np = np.sum(arr, where=True) + out_num = num.sum(arr, where=True) + assert allclose(out_np, out_num) + + out_np = np.sum(arr, where=False) + out_num = num.sum(arr, where=False) + assert allclose(out_np, out_num) + class TestSumPositive(object): """