From 2097d654e91aed291d1b451c7347df95aa64fc82 Mon Sep 17 00:00:00 2001 From: Andrei Schaffer Date: Thu, 15 Feb 2024 18:16:41 -0600 Subject: [PATCH] Fix to account to sort()'s behavior with NaNs on device. --- cunumeric/module.py | 41 ++++++++++++++++++++++++----------------- 1 file changed, 24 insertions(+), 17 deletions(-) diff --git a/cunumeric/module.py b/cunumeric/module.py index 746906ab8..937765320 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -8281,16 +8281,17 @@ def percentile( # qs_all: [in/out] result pass through or created (returned) # def nanquantile_impl( - arr: np.ndarray, - q_arr: np.ndarray[Any], + arr: ndarray, + q_arr: npt.NDArray[Any], + non_nan_counts: ndarray, axis: Optional[int], axes_set: Iterable[int], original_shape: tuple[int, ...], method: Callable[[float, int], tuple[float, int]], keepdims: bool, to_dtype: np.dtype[Any], - qs_all: Optional[np.ndarray], -) -> np.ndarray: + qs_all: Optional[ndarray], +) -> ndarray: ndims = len(arr.shape) if axis is None: @@ -8319,7 +8320,7 @@ def nanquantile_impl( # qresult_shape = (*q_arr.shape, *remaining_shape) - # construct result Np.Ndarray, non-flattening approach: + # construct result Ndarray, non-flattening approach: # if qs_all is None: qs_all = zeros(qresult_shape, dtype=to_dtype) @@ -8329,9 +8330,6 @@ def nanquantile_impl( if qs_all.shape != qresult_shape: raise ValueError("wrong shape on output array") - # ndarray of non-NaNs: - # - non_nan_counts = count_nonzero((isnan(arr) == False), axis=axis) assert non_nan_counts.shape == remaining_shape arr_ones = ones(remaining_shape, dtype=arr.dtype) @@ -8393,14 +8391,14 @@ def nanquantile_impl( @add_boilerplate("a") def nanquantile( - a: np.ndarray, - q: Union[float, Iterable[float], np.ndarray], + a: ndarray, + q: Union[float, Iterable[float], ndarray], axis: Union[None, int, tuple[int, ...]] = None, - out: Optional[np.ndarray] = None, + out: Optional[ndarray] = None, overwrite_input: bool = False, method: str = "linear", keepdims: bool = False, -) -> np.ndarray: +) -> ndarray: """ Compute the q-th quantile of the data along the specified axis, while ignoring nan values. @@ -8416,7 +8414,7 @@ def nanquantile( axis : {int, tuple of int, None}, optional Axis or axes along which the quantiles are computed. The default is to compute the quantile(s) along a flattened version of the array. - out : np.ndarray, optional + out : ndarray, optional Alternative output array in which to place the result. It must have the same shape as the expected output. overwrite_input : bool, optional @@ -8450,7 +8448,7 @@ def nanquantile( Returns ------- - quantile : scalar or np.ndarray + quantile : scalar or ndarray If `q` is a single quantile and `axis=None`, then the result is a scalar. If multiple quantiles are given, first axis of the result corresponds to the quantiles. The other axes are @@ -8516,6 +8514,14 @@ def nanquantile( if real_axis is not None: axes_set = [real_axis] + # ndarray of non-NaNs: + # + non_nan_counts = count_nonzero((isnan(a_rr) == False), axis=real_axis) + + # replace NaN's by dtype.max: + # + b_rr = where(isnan(a_rr), np.finfo(a_rr.dtype).max, a_rr) + # covers both array-like and scalar cases: # q_arr = np.asarray(q) @@ -8528,10 +8534,10 @@ def nanquantile( # if no axis given then elements are sorted as a 1D array # if overwrite_input: - a_rr.sort(axis=real_axis) - arr = a_rr + b_rr.sort(axis=real_axis) + arr = b_rr else: - arr = np.sort(a_rr, axis=real_axis) + arr = sort(b_rr, axis=real_axis) if arr.dtype.kind == "c": raise TypeError("input array cannot be of complex type") @@ -8572,6 +8578,7 @@ def nanquantile( res = nanquantile_impl( arr, q_arr, + non_nan_counts, real_axis, axes_set, original_shape,