@add_boilerplate("x", "bins")
def digitize(
    x: ndarray,
    bins: ndarray,
    right: bool = False,
) -> Union[int, ndarray]:
    """
    Return the indices of the bins to which each value in input array belongs.

    =========  =============  ============================
    `right`    order of bins  returned index `i` satisfies
    =========  =============  ============================
    ``False``  increasing     ``bins[i-1] <= x < bins[i]``
    ``True``   increasing     ``bins[i-1] < x <= bins[i]``
    ``False``  decreasing     ``bins[i-1] > x >= bins[i]``
    ``True``   decreasing     ``bins[i-1] >= x > bins[i]``
    =========  =============  ============================

    If values in `x` are beyond the bounds of `bins`, 0 or ``len(bins)`` is
    returned as appropriate.

    Parameters
    ----------
    x : array_like
        Input array to be binned. Doesn't need to be 1-dimensional.
    bins : array_like
        Array of bins. It has to be 1-dimensional and monotonic.
    right : bool, optional
        Indicating whether the intervals include the right or the left bin
        edge. Default behavior is (right==False) indicating that the interval
        does not include the right edge. The left bin end is closed in this
        case, i.e., bins[i-1] <= x < bins[i] is the default behavior for
        monotonically increasing bins.

    Returns
    -------
    indices : ndarray of ints
        Output array of indices, of same shape as `x`.

    Raises
    ------
    ValueError
        If `bins` is not one-dimensional, or is not monotonic.
    TypeError
        If the type of the input is complex.

    See Also
    --------
    numpy.digitize

    Notes
    -----
    If values in `x` are such that they fall outside the bin range,
    attempting to index `bins` with the indices that `digitize` returns
    will result in an IndexError.

    Bins that are both non-decreasing and non-increasing (all edges equal,
    a single edge, or no edges at all) are treated as *increasing*.

    For monotonically *increasing* `bins`, the following are equivalent::

        np.digitize(x, bins, right=True)
        np.searchsorted(bins, x, side='left')

    Note that as the order of the arguments is reversed, the side must be too.
    The `searchsorted` call is marginally faster, as it does not do any
    monotonicity checks. Perhaps more importantly, it supports all dtypes.

    Examples
    --------
    >>> x = np.array([0.2, 6.4, 3.0, 1.6])
    >>> bins = np.array([0.0, 1.0, 2.5, 4.0, 10.0])
    >>> inds = np.digitize(x, bins)
    >>> inds
    array([1, 4, 3, 2])
    >>> for n in range(x.size):
    ...   print(bins[inds[n]-1], "<=", x[n], "<", bins[inds[n]])
    ...
    0.0 <= 0.2 < 1.0
    4.0 <= 6.4 < 10.0
    2.5 <= 3.0 < 4.0
    1.0 <= 1.6 < 2.5

    >>> x = np.array([1.2, 10.0, 12.4, 15.5, 20.])
    >>> bins = np.array([0, 5, 10, 15, 20])
    >>> np.digitize(x,bins,right=True)
    array([1, 2, 3, 4, 4])
    >>> np.digitize(x,bins,right=False)
    array([1, 3, 3, 4, 5])

    Availability
    --------
    Multiple GPUs, Multiple CPUs
    """
    # here for compatibility, searchsorted below is happy to take this
    if np.issubdtype(x.dtype, np.complexfloating):
        raise TypeError("x may not be complex")

    # Reject 0-d bins too: the edge comparisons below slice along axis 0,
    # which a scalar does not have.
    if bins.ndim != 1:
        raise ValueError("bins must be one-dimensional")

    # The decreasing check is only evaluated when the increasing one fails.
    increasing = (bins[1:] >= bins[:-1]).all()
    if not increasing and not (bins[1:] <= bins[:-1]).all():
        raise ValueError("bins must be monotonically increasing or decreasing")

    # this is backwards because the arguments below are swapped
    side: SortSide = "left" if right else "right"

    # Dispatch on `increasing`, not `decreasing`: constant, single-element
    # and empty bins satisfy both checks and must take the plain path,
    # otherwise every index would come back inverted.
    if increasing:
        return searchsorted(bins, x, side=side)

    # strictly-not-increasing bins: reverse the bins, and invert the results
    return len(bins) - searchsorted(bins.flip(), x, side=side)
# Copyright 2022 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

import math

import numpy as np
import pytest

import cunumeric as num

# Element types used to parametrize the randomized comparison tests.
DTYPES = (
    np.uint32,
    np.uint64,
    np.float32,
    np.float64,
)

# Input shapes (1-D, 2-D and 3-D) used to parametrize the same tests.
SHAPES = (
    (10,),
    (2, 5),
    (3, 7, 10),
)


class TestDigitizeErrors:
    """Check that invalid input raises the same exception type in both
    cuNumeric and NumPy.

    Every test calls ``digitize`` on cuNumeric first and NumPy second, each
    inside its own ``pytest.raises`` context.
    """

    def test_complex_array(self):
        """Complex-valued `x` must raise TypeError."""
        values = np.array([2, 3, 10, 9], dtype=np.complex64)
        edges = [0, 3, 5]
        for module in (num, np):
            with pytest.raises(TypeError):
                module.digitize(values, edges)

    @pytest.mark.xfail
    def test_bad_array(self):
        """``None`` as `x` (with non-monotonic bins) is expected to raise
        ValueError; marked xfail."""
        edges = [0, 5, 3]
        for module in (num, np):
            with pytest.raises(ValueError):
                # cunumeric raises TypeError
                module.digitize(None, edges)

    @pytest.mark.xfail
    def test_bad_bins(self):
        """``None`` as `bins` is expected to raise ValueError; marked xfail."""
        values = [2, 3, 10, 9]
        for module in (num, np):
            with pytest.raises(ValueError):
                # cunumeric raises TypeError
                module.digitize(values, None)

    def test_bins_non_monotonic(self):
        """Bins that neither rise nor fall monotonically raise ValueError."""
        values = [2, 3, 10, 9]
        edges = [0, 5, 3]
        for module in (num, np):
            with pytest.raises(ValueError):
                module.digitize(values, edges)

    def test_bins_ndim(self):
        """Two-dimensional bins raise ValueError."""
        values = [2, 3, 10, 9]
        edges = np.array([[0], [5], [3]])
        for module in (num, np):
            with pytest.raises(ValueError):
                module.digitize(values, edges)


def generate_random(shape, dtype):
    """Build a random NumPy array of the requested `shape` and `dtype`.

    Integer dtypes are sampled with ``np.random.randint`` between the
    dtype's ``iinfo`` min and max; floating dtypes with ``np.random.random``;
    complex dtypes combine two ``np.random.random`` draws as real and
    imaginary parts. Any other dtype trips an assertion.
    """
    flat = None
    count = math.prod(shape)
    if np.issubdtype(dtype, np.integer):
        limits = np.iinfo(dtype)
        flat = np.array(
            np.random.randint(
                limits.min, limits.max, size=count, dtype=dtype
            ),
            dtype=dtype,
        )
    elif np.issubdtype(dtype, np.floating):
        flat = np.array(np.random.random(size=count), dtype=dtype)
    elif np.issubdtype(dtype, np.complexfloating):
        # Draw the real part first, then the imaginary part.
        real_part = np.random.random(size=count)
        imag_part = np.random.random(size=count)
        flat = np.array(real_part + imag_part * 1j, dtype=dtype)
    else:
        assert False
    return flat.reshape(shape)


def _assert_matches_numpy(a, bins, right):
    """Compare ``num.digitize`` with ``np.digitize`` for three input mixes.

    The cuNumeric call is made with (NumPy `a`, list `bins`), then
    (cuNumeric `a`, list `bins`), then (cuNumeric `a`, cuNumeric `bins`);
    the NumPy reference is always computed from the original `a` and `bins`.
    """
    a_num = num.array(a)
    bins_num = num.array(bins)

    combos = ((a, bins), (a_num, bins), (a_num, bins_num))
    for values_in, bins_in in combos:
        expected = np.digitize(a, bins, right=right)
        actual = num.digitize(values_in, bins_in, right=right)
        assert num.array_equal(expected, actual)


@pytest.mark.parametrize("right", (True, False))
def test_empty(right):
    """An empty `x` produces an empty result."""
    edges = [0, 3, 5]
    result = num.digitize([], edges, right=right)
    assert len(result) == 0


@pytest.mark.parametrize("shape", SHAPES, ids=str)
@pytest.mark.parametrize("dtype", DTYPES, ids=str)
@pytest.mark.parametrize("right", (True, False))
def test_increasing_bins(shape, dtype, right):
    """Random input against increasing bins matches NumPy."""
    data = generate_random(shape, dtype)
    _assert_matches_numpy(data, [0, 3, 5], right)


@pytest.mark.parametrize("shape", SHAPES, ids=str)
@pytest.mark.parametrize("dtype", DTYPES, ids=str)
@pytest.mark.parametrize("right", (True, False))
def test_decreasing_bins(shape, dtype, right):
    """Random input against decreasing bins matches NumPy."""
    data = generate_random(shape, dtype)
    _assert_matches_numpy(data, [5, 3, 0], right)


if __name__ == "__main__":
    import sys

    sys.exit(pytest.main(sys.argv))