Skip to content

Commit

Permalink
Fix to account to sort()'s behavior with NaNs on device.
Browse files Browse the repository at this point in the history
  • Loading branch information
aschaffer committed Feb 16, 2024
1 parent 4722bba commit 2097d65
Showing 1 changed file with 24 additions and 17 deletions.
41 changes: 24 additions & 17 deletions cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -8281,16 +8281,17 @@ def percentile(
# qs_all: [in/out] result pass through or created (returned)
#
def nanquantile_impl(
arr: np.ndarray,
q_arr: np.ndarray[Any],
arr: ndarray,
q_arr: npt.NDArray[Any],
non_nan_counts: ndarray,
axis: Optional[int],
axes_set: Iterable[int],
original_shape: tuple[int, ...],
method: Callable[[float, int], tuple[float, int]],
keepdims: bool,
to_dtype: np.dtype[Any],
qs_all: Optional[np.ndarray],
) -> np.ndarray:
qs_all: Optional[ndarray],
) -> ndarray:
ndims = len(arr.shape)

if axis is None:
Expand Down Expand Up @@ -8319,7 +8320,7 @@ def nanquantile_impl(
#
qresult_shape = (*q_arr.shape, *remaining_shape)

# construct result Np.Ndarray, non-flattening approach:
# construct result Ndarray, non-flattening approach:
#
if qs_all is None:
qs_all = zeros(qresult_shape, dtype=to_dtype)
Expand All @@ -8329,9 +8330,6 @@ def nanquantile_impl(
if qs_all.shape != qresult_shape:
raise ValueError("wrong shape on output array")

# ndarray of non-NaNs:
#
non_nan_counts = count_nonzero((isnan(arr) == False), axis=axis)
assert non_nan_counts.shape == remaining_shape

arr_ones = ones(remaining_shape, dtype=arr.dtype)
Expand Down Expand Up @@ -8393,14 +8391,14 @@ def nanquantile_impl(

@add_boilerplate("a")
def nanquantile(
a: np.ndarray,
q: Union[float, Iterable[float], np.ndarray],
a: ndarray,
q: Union[float, Iterable[float], ndarray],
axis: Union[None, int, tuple[int, ...]] = None,
out: Optional[np.ndarray] = None,
out: Optional[ndarray] = None,
overwrite_input: bool = False,
method: str = "linear",
keepdims: bool = False,
) -> np.ndarray:
) -> ndarray:
"""
Compute the q-th quantile of the data along the specified axis,
while ignoring nan values.
Expand All @@ -8416,7 +8414,7 @@ def nanquantile(
axis : {int, tuple of int, None}, optional
Axis or axes along which the quantiles are computed. The default is
to compute the quantile(s) along a flattened version of the array.
out : np.ndarray, optional
out : ndarray, optional
Alternative output array in which to place the result. It must have
the same shape as the expected output.
overwrite_input : bool, optional
Expand Down Expand Up @@ -8450,7 +8448,7 @@ def nanquantile(
Returns
-------
quantile : scalar or np.ndarray
quantile : scalar or ndarray
If `q` is a single quantile and `axis=None`, then the result
is a scalar. If multiple quantiles are given, first axis of
the result corresponds to the quantiles. The other axes are
Expand Down Expand Up @@ -8516,6 +8514,14 @@ def nanquantile(
if real_axis is not None:
axes_set = [real_axis]

# ndarray of non-NaNs:
#
non_nan_counts = count_nonzero((isnan(a_rr) == False), axis=real_axis)

# replace NaN's by dtype.max:
#
b_rr = where(isnan(a_rr), np.finfo(a_rr.dtype).max, a_rr)

# covers both array-like and scalar cases:
#
q_arr = np.asarray(q)
Expand All @@ -8528,10 +8534,10 @@ def nanquantile(
# if no axis given then elements are sorted as a 1D array
#
if overwrite_input:
a_rr.sort(axis=real_axis)
arr = a_rr
b_rr.sort(axis=real_axis)
arr = b_rr
else:
arr = np.sort(a_rr, axis=real_axis)
arr = sort(b_rr, axis=real_axis)

if arr.dtype.kind == "c":
raise TypeError("input array cannot be of complex type")
Expand Down Expand Up @@ -8572,6 +8578,7 @@ def nanquantile(
res = nanquantile_impl(
arr,
q_arr,
non_nan_counts,
real_axis,
axes_set,
original_shape,
Expand Down

0 comments on commit 2097d65

Please sign in to comment.