diff --git a/cunumeric/_ufunc/ufunc.py b/cunumeric/_ufunc/ufunc.py index 11800e53f..3bba6b8ad 100644 --- a/cunumeric/_ufunc/ufunc.py +++ b/cunumeric/_ufunc/ufunc.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 NVIDIA Corporation +# Copyright 2021-2023 NVIDIA Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -14,12 +14,17 @@ # from __future__ import annotations -from typing import TYPE_CHECKING, Any, Dict, Sequence, Union +from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union import numpy as np from legate.core.utils import OrderedSet -from ..array import check_writeable, convert_to_cunumeric_ndarray, ndarray +from ..array import ( + add_boilerplate, + check_writeable, + convert_to_cunumeric_ndarray, + ndarray, +) from ..config import BinaryOpCode, UnaryOpCode, UnaryRedCode from ..types import NdShape @@ -680,6 +685,7 @@ def __call__( return self._maybe_cast_output(out, result) + @add_boilerplate("array") def reduce( self, array: ndarray, @@ -688,7 +694,7 @@ def reduce( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Union[Any, None] = None, - where: bool = True, + where: Optional[ndarray] = None, ) -> ndarray: """ reduce(array, axis=0, dtype=None, out=None, keepdims=False, initial= R: for k, v in kwargs.items(): if v is None: continue - elif k == "where": - kwargs[k] = convert_to_predicate_ndarray(v) elif k == "out": kwargs[k] = convert_to_cunumeric_ndarray(v, share=True) if not kwargs[k].flags.writeable: raise ValueError("out is not writeable") - elif k in keys: + elif (k in keys) or (k == "where"): kwargs[k] = convert_to_cunumeric_ndarray(v) return func(*args, **kwargs) @@ -164,16 +161,6 @@ def convert_to_cunumeric_ndarray(obj: Any, share: bool = False) -> ndarray: return ndarray(shape=None, thunk=thunk, writeable=writeable) -def convert_to_predicate_ndarray(obj: Any) -> bool: - # Keep all boolean types as they are - if obj is True or obj is False: - return obj - # GH #135 - raise NotImplementedError( - "the `where` parameter is currently not supported" - ) - - def maybe_convert_to_np_ndarray(obj: Any) -> Any: """ Converts cuNumeric arrays into NumPy arrays, otherwise has no effect. @@ -198,6 +185,16 @@ def check_writeable(arr: Union[ndarray, tuple[ndarray, ...], None]) -> None: raise ValueError("array is not writeable") +def broadcast_where( + where: Union[ndarray, None], shape: NdShape +) -> Union[ndarray, None]: + if where is not None and where.shape != shape: + from .module import broadcast_to + + where = broadcast_to(where, shape) + return where + + class flagsobj: """ Information about the memory layout of the array. @@ -1767,7 +1764,7 @@ def all( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Union[int, float, None] = None, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: """a.all(axis=None, out=None, keepdims=False, initial=None, where=True) @@ -1802,7 +1799,7 @@ def any( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Union[int, float, None] = None, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: """a.any(axis=None, out=None, keepdims=False, initial=None, where=True) @@ -3044,7 +3041,7 @@ def max( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Union[int, float, None] = None, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: """a.max(axis=None, out=None, keepdims=False, initial=, where=True) @@ -3072,6 +3069,16 @@ def max( where=where, ) + def _count_nonzero(self, axis: Any = None) -> Union[int, ndarray]: + if self.size == 0: + return 0 + return ndarray._perform_unary_reduction( + UnaryRedCode.COUNT_NONZERO, + self, + res_dtype=np.dtype(np.uint64), + axis=axis, + ) + def _summation_dtype( self, dtype: Optional[np.dtype[Any]] ) -> np.dtype[Any]: @@ -3084,21 +3091,42 @@ def _summation_dtype( return dtype def _normalize_summation( - self, sum_array: Any, axis: Any, dtype: np.dtype[Any], ddof: int = 0 + self, + sum_array: Any, + axis: Any, + ddof: int = 0, + keepdims: bool = False, + where: Union[ndarray, None] = None, ) -> None: + dtype = sum_array.dtype if axis is None: - divisor = reduce(lambda x, y: x * y, self.shape, 1) - ddof + if where is not None: + divisor = where._count_nonzero() - ddof + else: + divisor = reduce(lambda x, y: x * y, self.shape, 1) - ddof else: - divisor = self.shape[axis] - ddof + if where is not None: + divisor = where.sum(axis=axis, dtype=dtype, keepdims=keepdims) + if ddof != 0 and not np.isscalar(divisor): + mask = divisor != 0 + values = divisor - ddof + divisor._thunk.putmask(mask._thunk, values._thunk) + else: + divisor -= ddof + else: + divisor = self.shape[axis] - ddof # Divide by the number of things in the collapsed dimensions # Pick the right kinds of division based on the dtype + if isinstance(divisor, ndarray): + divisor = divisor.astype(dtype) + else: + divisor = np.array(divisor, dtype=dtype) # type: ignore [assignment] # noqa + if dtype.kind == "f" or dtype.kind == "c": - sum_array.__itruediv__( - np.array(divisor, dtype=sum_array.dtype), - ) + sum_array.__itruediv__(divisor) else: - sum_array.__ifloordiv__(np.array(divisor, dtype=sum_array.dtype)) + sum_array.__ifloordiv__(divisor) @add_boilerplate() def mean( @@ -3107,6 +3135,7 @@ def mean( dtype: Optional[np.dtype[Any]] = None, out: Optional[ndarray] = None, keepdims: bool = False, + where: Union[ndarray, None] = None, ) -> ndarray: """a.mean(axis=None, dtype=None, out=None, keepdims=False) @@ -3130,23 +3159,26 @@ def mean( ) dtype = self._summation_dtype(dtype) + where_array = broadcast_where(where, self.shape) # Do the sum - if out is not None and out.dtype == dtype: - sum_array = self.sum( + sum_array = ( + self.sum( axis=axis, - dtype=dtype, out=out, keepdims=keepdims, - ) - else: - sum_array = self.sum( - axis=axis, dtype=dtype, - keepdims=keepdims, + where=where_array, + ) + if out is not None and out.dtype == dtype + else self.sum( + axis=axis, keepdims=keepdims, dtype=dtype, where=where_array ) + ) - self._normalize_summation(sum_array, axis, dtype) + self._normalize_summation( + sum_array, axis, keepdims=keepdims, where=where_array + ) # Convert to the output we didn't already put it there if out is not None and sum_array is not out: @@ -3156,6 +3188,34 @@ def mean( else: return sum_array + def _nanmean( + self, + axis: Optional[Union[int, tuple[int, ...]]] = None, + dtype: Union[np.dtype[Any], None] = None, + out: Union[ndarray, None] = None, + keepdims: bool = False, + where: Union[ndarray, None] = None, + ) -> ndarray: + from . import _ufunc + + if np.issubdtype(dtype, np.integer) or np.issubdtype(dtype, np.bool_): + return self.mean( + axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where + ) + + nan_mask = _ufunc.bit_twiddling.bitwise_not( + _ufunc.floating.isnan(self) + ) + if where is not None: + nan_mask &= where + return self.mean( + axis=axis, + dtype=dtype, + out=out, + keepdims=keepdims, + where=nan_mask, + ) + @add_boilerplate() def var( self, @@ -3165,7 +3225,7 @@ def var( ddof: int = 0, keepdims: bool = False, *, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: """a.var(a, axis=None, dtype=None, out=None, ddof=0, keepdims=False) @@ -3199,10 +3259,11 @@ def var( dtype = self._summation_dtype(dtype) # calculate the mean, but keep the dimensions so that the # mean can be broadcast against the original array - mu = self.mean(axis=axis, dtype=dtype, keepdims=True) + mu = self.mean(axis=axis, dtype=dtype, keepdims=True, where=where) + + where_array = broadcast_where(where, self.shape) # 1D arrays (or equivalent) should benefit from this unary reduction: - # if axis is None or calculate_volume(tuple_pop(self.shape, axis)) == 1: # this is a scalar reduction and we can optimize this as a single # pass through a scalar reduction @@ -3213,7 +3274,7 @@ def var( dtype=dtype, out=out, keepdims=keepdims, - where=where, + where=where_array, args=(mu,), ) else: @@ -3234,10 +3295,16 @@ def var( dtype=dtype, out=out, keepdims=keepdims, - where=where, + where=where_array, ) - self._normalize_summation(result, axis=axis, dtype=dtype, ddof=ddof) + self._normalize_summation( + result, + axis=axis, + ddof=ddof, + keepdims=keepdims, + where=where_array, + ) return result @@ -3248,7 +3315,7 @@ def min( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Union[int, float, None] = None, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: """a.min(axis=None, out=None, keepdims=False, initial=, where=True) @@ -3346,7 +3413,7 @@ def prod( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Union[int, float, None] = None, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: """a.prod(axis=None, dtype=None, out=None, keepdims=False, initial=1, where=True) @@ -3716,10 +3783,10 @@ def sum( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Union[int, float, None] = None, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: """a.sum(axis=None, dtype=None, out=None, keepdims=False, initial=0, - where=True) + where=None) Return the sum of the array elements over the given axis. @@ -3755,6 +3822,34 @@ def sum( where=where, ) + def _nansum( + self, + axis: Any = None, + dtype: Any = None, + out: Union[ndarray, None] = None, + keepdims: bool = False, + initial: Optional[Union[int, float]] = None, + where: Optional[ndarray] = None, + ) -> ndarray: + # Note that np.nansum and np.sum allow complex datatypes + # so there are no "disallowed types" for this API + + if self.dtype.kind in ("f", "c"): + unary_red_code = UnaryRedCode.NANSUM + else: + unary_red_code = UnaryRedCode.SUM + + return self._perform_unary_reduction( + unary_red_code, + self, + axis=axis, + dtype=dtype, + out=out, + keepdims=keepdims, + initial=initial, + where=where, + ) + def swapaxes(self, axis1: Any, axis2: Any) -> ndarray: """a.swapaxes(axis1, axis2) @@ -4064,16 +4159,16 @@ def unique(self) -> ndarray: @classmethod def _get_where_thunk( - cls, where: Union[bool, ndarray], out_shape: NdShape - ) -> Union[Literal[True], NumPyThunk]: - if where is True: - return True - if where is False: - raise RuntimeError("should have caught this earlier") - if not isinstance(where, ndarray) or where.dtype != np.bool_: + cls, where: Union[None, ndarray], out_shape: NdShape + ) -> Union[None, NumPyThunk]: + if where is None: + return where + if ( + not isinstance(where, ndarray) + or where.dtype != np.bool_ + or where.shape != out_shape + ): raise RuntimeError("should have converted this earlier") - if where.shape != out_shape: - raise ValueError("where parameter must have same shape as output") return where._thunk @staticmethod @@ -4126,33 +4221,31 @@ def _perform_unary_op( out: Union[Any, None] = None, extra_args: Any = None, dtype: Union[np.dtype[Any], None] = None, - where: Union[bool, ndarray] = True, out_dtype: Union[np.dtype[Any], None] = None, ) -> ndarray: if out is not None: # If the shapes don't match see if we can broadcast # This will raise an exception if they can't be broadcast together - if isinstance(where, ndarray): - np.broadcast_shapes(src.shape, out.shape, where.shape) - else: - np.broadcast_shapes(src.shape, out.shape) + if np.broadcast_shapes(src.shape, out.shape) != out.shape: + raise ValueError( + f"non-broadcastable output operand with shape {out.shape} " + f"doesn't match the broadcast shape {src.shape}" + ) else: # No output yet, so make one - if isinstance(where, ndarray): - out_shape = np.broadcast_shapes(src.shape, where.shape) - else: - out_shape = src.shape + out_shape = src.shape + if dtype is not None: out = ndarray( shape=out_shape, dtype=dtype, - inputs=(src, where), + inputs=(src,), ) elif out_dtype is not None: out = ndarray( shape=out_shape, dtype=out_dtype, - inputs=(src, where), + inputs=(src,), ) else: out = ndarray( @@ -4162,13 +4255,9 @@ def _perform_unary_op( else np.dtype(np.float32) if src.dtype == np.dtype(np.complex64) else np.dtype(np.float64), - inputs=(src, where), + inputs=(src,), ) - # Quick exit - if where is False: - return out - if out_dtype is None: if out.dtype != src.dtype and not ( op == UnaryOpCode.ABSOLUTE and src.dtype.kind == "c" @@ -4176,12 +4265,12 @@ def _perform_unary_op( temp = ndarray( out.shape, dtype=src.dtype, - inputs=(src, where), + inputs=(src,), ) temp._thunk.unary_op( op, src._thunk, - cls._get_where_thunk(where, out.shape), + True, extra_args, ) out._thunk.convert(temp._thunk) @@ -4189,7 +4278,7 @@ def _perform_unary_op( out._thunk.unary_op( op, src._thunk, - cls._get_where_thunk(where, out.shape), + True, extra_args, ) else: @@ -4197,12 +4286,12 @@ def _perform_unary_op( temp = ndarray( out.shape, dtype=out_dtype, - inputs=(src, where), + inputs=(src,), ) temp._thunk.unary_op( op, src._thunk, - cls._get_where_thunk(where, out.shape), + True, extra_args, ) out._thunk.convert(temp._thunk) @@ -4210,7 +4299,7 @@ def _perform_unary_op( out._thunk.unary_op( op, src._thunk, - cls._get_where_thunk(where, out.shape), + True, extra_args, ) return out @@ -4228,7 +4317,7 @@ def _perform_unary_reduction( keepdims: bool = False, args: Union[Any, None] = None, initial: Union[int, float, None] = None, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: # When 'res_dtype' is not None, the input and output of the reduction # have different types. Such reduction operators don't take a dtype of @@ -4250,13 +4339,6 @@ def _perform_unary_reduction( # TODO: Need to require initial to be given when the array is empty # or a where mask is given. - if isinstance(where, ndarray): - # The where array has to broadcast to the src.shape - if np.broadcast_shapes(src.shape, where.shape) != src.shape: - raise ValueError( - '"where" array must broadcast against source array ' - "for reduction" - ) if ( op in ( @@ -4304,17 +4386,17 @@ def _perform_unary_reduction( shape=out_shape, dtype=res_dtype, inputs=(src, where) ) - if where: - result._thunk.unary_reduction( - op, - src._thunk, - cls._get_where_thunk(where, result.shape), - axis, - axes, - keepdims, - args, - initial, - ) + where_array = broadcast_where(where, src.shape) + result._thunk.unary_reduction( + op, + src._thunk, + cls._get_where_thunk(where_array, src.shape), + axis, + axes, + keepdims, + args, + initial, + ) if result is not out: out._thunk.convert(result._thunk) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index a67d9912d..fb288a205 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 NVIDIA Corporation +# Copyright 2021-2023 NVIDIA Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -3131,7 +3131,7 @@ def unary_op( # Perform a unary reduction operation from one set of dimensions down to # fewer - @auto_convert("src") + @auto_convert("src", "where") def unary_reduction( self, op: UnaryRedCode, @@ -3162,6 +3162,7 @@ def unary_reduction( inputs=[self], ) + is_where = bool(where is not None) # See if we are doing reduction to a point or another region if lhs_array.size == 1: assert axes is None or lhs_array.ndim == rhs_array.ndim - ( @@ -3189,6 +3190,10 @@ def unary_reduction( task.add_input(rhs_array.base) task.add_scalar_arg(op, ty.int32) task.add_scalar_arg(rhs_array.shape, (ty.int64,)) + task.add_scalar_arg(is_where, ty.bool_) + if is_where: + task.add_input(where.base) + task.add_alignment(rhs_array.base, where.base) self.add_arguments(task, args) @@ -3228,6 +3233,10 @@ def unary_reduction( task.add_reduction(result, _UNARY_RED_TO_REDUCTION_OPS[op]) task.add_scalar_arg(axis, ty.int32) task.add_scalar_arg(op, ty.int32) + task.add_scalar_arg(is_where, ty.bool_) + if is_where: + task.add_input(where.base) + task.add_alignment(rhs_array.base, where.base) self.add_arguments(task, args) diff --git a/cunumeric/eager.py b/cunumeric/eager.py index 26fc98016..7c58023ef 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 NVIDIA Corporation +# Copyright 2021-2023 NVIDIA Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1499,6 +1499,8 @@ def unary_reduction( initial, ) return + if where is None: + where = True if op in _UNARY_RED_OPS_WITH_ARG: fn = _UNARY_RED_OPS_WITH_ARG[op] # arg based APIs don't have the following arguments: where, initial @@ -1530,7 +1532,9 @@ def unary_reduction( squared, out=self.array, axis=orig_axis, - where=where, + where=where + if not isinstance(where, EagerArray) + else where.array, keepdims=keepdims, ) elif op == UnaryRedCode.VARIANCE: @@ -1540,7 +1544,9 @@ def unary_reduction( np.sum( squares, axis=orig_axis, - where=where, + where=where + if not isinstance(where, EagerArray) + else where.array, keepdims=keepdims, out=self.array, ) diff --git a/cunumeric/linalg/cholesky.py b/cunumeric/linalg/cholesky.py index 4ff4fe212..eed4c3188 100644 --- a/cunumeric/linalg/cholesky.py +++ b/cunumeric/linalg/cholesky.py @@ -208,7 +208,7 @@ def _batched_cholesky(output: DeferredArray, input: DeferredArray) -> None: # wildly varying memory available depending on the system. # Just use a fixed cutoff to provide some sensible warning. # TODO: find a better way to inform the user dims are too big - context: Context = output.context + context: Context = output.context # type: ignore task = context.create_auto_task(CuNumericOpCode.BATCHED_CHOLESKY) task.add_input(input.base) task.add_output(output.base) diff --git a/cunumeric/module.py b/cunumeric/module.py index d4a801b6a..45d7508b1 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 NVIDIA Corporation +# Copyright 2021-2023 NVIDIA Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -1487,6 +1487,10 @@ def _broadcast_to( arr = array(arr, copy=False, subok=subok) # 'broadcast_to' returns a read-only view of the original array out_shape = broadcast_shapes(arr.shape, shape) + if out_shape != shape: + raise ValueError( + f"cannot broadcast an array of shape {arr.shape} to {shape}" + ) result = ndarray( shape=out_shape, thunk=arr._thunk.broadcast_to(out_shape), @@ -4939,7 +4943,7 @@ def all( axis: Optional[Union[int, tuple[int, ...]]] = None, out: Optional[ndarray] = None, keepdims: bool = False, - where: bool = True, + where: Optional[ndarray] = None, ) -> ndarray: """ Test whether all array elements along a given axis evaluate to True. @@ -4997,7 +5001,7 @@ def any( axis: Optional[Union[int, tuple[int, ...]]] = None, out: Optional[ndarray] = None, keepdims: bool = False, - where: bool = True, + where: Optional[ndarray] = None, ) -> ndarray: """ Test whether any array element along a given axis evaluates to True. @@ -5257,7 +5261,7 @@ def prod( out: Optional[ndarray] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, - where: bool = True, + where: Optional[ndarray] = None, ) -> ndarray: """ @@ -5338,7 +5342,7 @@ def sum( out: Optional[ndarray] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, - where: bool = True, + where: Optional[ndarray] = None, ) -> ndarray: """ @@ -5798,7 +5802,7 @@ def nanmin( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, - where: Any = True, + where: Optional[ndarray] = None, ) -> ndarray: """ Return minimum of an array or minimum along an axis, ignoring any @@ -5878,8 +5882,8 @@ def nanmin( ) if cunumeric_settings.numpy_compat() and a.dtype.kind == "f": - where = all(isnan(a), axis=axis, keepdims=keepdims) - putmask(out_array, where, np.nan) # type: ignore + all_nan = all(isnan(a), axis=axis, keepdims=keepdims, where=where) + putmask(out_array, all_nan, np.nan) # type: ignore return out_array @@ -5891,7 +5895,7 @@ def nanmax( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, - where: Any = True, + where: Optional[ndarray] = None, ) -> ndarray: """ Return the maximum of an array or maximum along an axis, ignoring any @@ -5974,8 +5978,8 @@ def nanmax( ) if cunumeric_settings.numpy_compat() and a.dtype.kind == "f": - where = all(isnan(a), axis=axis, keepdims=keepdims) - putmask(out_array, where, np.nan) # type: ignore + all_nan = all(isnan(a), axis=axis, keepdims=keepdims, where=where) + putmask(out_array, all_nan, np.nan) # type: ignore return out_array @@ -5988,7 +5992,7 @@ def nanprod( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, - where: Any = True, + where: Optional[ndarray] = None, ) -> ndarray: """ Return the product of array elements over a given axis treating @@ -6084,7 +6088,7 @@ def nansum( out: Union[ndarray, None] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, - where: Any = True, + where: Optional[ndarray] = None, ) -> ndarray: """ Return the sum of array elements over a given axis treating @@ -6151,17 +6155,7 @@ def nansum( Multiple GPUs, Multiple CPUs """ - # Note that np.nansum and np.sum allow complex datatypes - # so there are no "disallowed types" for this API - - if a.dtype.kind in ("f", "c"): - unary_red_code = UnaryRedCode.NANSUM - else: - unary_red_code = UnaryRedCode.SUM - - return a._perform_unary_reduction( - unary_red_code, - a, + return a._nansum( axis=axis, dtype=dtype, out=out, @@ -6248,7 +6242,7 @@ def amax( out: Optional[ndarray] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, - where: bool = True, + where: Optional[ndarray] = None, ) -> ndarray: """ @@ -6325,7 +6319,7 @@ def amin( out: Optional[ndarray] = None, keepdims: bool = False, initial: Optional[Union[int, float]] = None, - where: bool = True, + where: Optional[ndarray] = None, ) -> ndarray: """ @@ -7053,14 +7047,7 @@ def count_nonzero( -------- Multiple GPUs, Multiple CPUs """ - if a.size == 0: - return 0 - return ndarray._perform_unary_reduction( - UnaryRedCode.COUNT_NONZERO, - a, - res_dtype=np.dtype(np.uint64), - axis=axis, - ) + return a._count_nonzero(axis) ############ @@ -7077,6 +7064,7 @@ def mean( dtype: Optional[np.dtype[Any]] = None, out: Optional[ndarray] = None, keepdims: bool = False, + where: Optional[ndarray] = None, ) -> ndarray: """ @@ -7118,10 +7106,13 @@ def mean( sub-class' method does not implement `keepdims` any exceptions will be raised. + where : array_like of bool, optional + Elements to include in the mean. + Returns ------- m : ndarray - If `out=None`, returns a new array of the same dtype a above + If `out is None`, returns a new array of the same dtype a above containing the mean values, otherwise a reference to the output array is returned. @@ -7133,7 +7124,76 @@ def mean( -------- Multiple GPUs, Multiple CPUs """ - return a.mean(axis=axis, dtype=dtype, out=out, keepdims=keepdims) + return a.mean( + axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where + ) + + +@add_boilerplate("a") +def nanmean( + a: ndarray, + axis: Optional[Union[int, tuple[int, ...]]] = None, + dtype: Optional[np.dtype[Any]] = None, + out: Optional[ndarray] = None, + keepdims: bool = False, + where: Optional[ndarray] = None, +) -> ndarray: + """ + + Compute the arithmetic mean along the specified axis, ignoring NaNs. + + Returns the average of the array elements. The average is taken over + the flattened array by default, otherwise over the specified axis. + `float64` intermediate and return values are used for integer inputs. + + Parameters + ---------- + a : array_like + Array containing numbers whose mean is desired. If `a` is not an + array, a conversion is attempted. + axis : None or int or tuple[int], optional + Axis or axes along which the means are computed. The default is to + compute the mean of the flattened array. + + If this is a tuple of ints, a mean is performed over multiple axes, + instead of a single axis or all the axes as before. + dtype : data-type, optional + Type to use in computing the mean. For integer inputs, the default + is `float64`; for floating point inputs, it is the same as the + input dtype. + out : ndarray, optional + Alternate output array in which to place the result. The default + is ``None``; if provided, it must have the same shape as the + expected output, but the type will be cast if necessary. + See `ufuncs-output-type` for more details. + + keepdims : bool, optional + If this is set to True, the axes which are reduced are left + in the result as dimensions with size one. With this option, + the result will broadcast correctly against the input array. + + + where : array_like of bool, optional + Elements to include in the mean. + + Returns + ------- + m : ndarray + If `out is None`, returns a new array of the same dtype as a above + containing the mean values, otherwise a reference to the output + array is returned. + + See Also + -------- + numpy.nanmean + + Availability + -------- + Multiple GPUs, Multiple CPUs + """ + return a._nanmean( + axis=axis, dtype=dtype, out=out, keepdims=keepdims, where=where + ) @add_boilerplate("a") @@ -7145,7 +7205,7 @@ def var( ddof: int = 0, keepdims: bool = False, *, - where: Union[bool, ndarray] = True, + where: Union[ndarray, None] = None, ) -> ndarray: """ Compute the variance along the specified axis. diff --git a/docs/cunumeric/source/api/statistics.rst b/docs/cunumeric/source/api/statistics.rst index 9227c93ae..48f10f19c 100644 --- a/docs/cunumeric/source/api/statistics.rst +++ b/docs/cunumeric/source/api/statistics.rst @@ -10,6 +10,7 @@ Averages and variances :toctree: generated/ mean + nanmean var diff --git a/src/cunumeric/unary/scalar_unary_red.cc b/src/cunumeric/unary/scalar_unary_red.cc index 77f746eb4..c10e065f9 100644 --- a/src/cunumeric/unary/scalar_unary_red.cc +++ b/src/cunumeric/unary/scalar_unary_red.cc @@ -1,4 +1,4 @@ -/* Copyright 2021-2022 NVIDIA Corporation +/* Copyright 2021-2023 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/cunumeric/unary/scalar_unary_red.cu b/src/cunumeric/unary/scalar_unary_red.cu index 6959d05c9..76dcaeb32 100644 --- a/src/cunumeric/unary/scalar_unary_red.cu +++ b/src/cunumeric/unary/scalar_unary_red.cu @@ -1,4 +1,4 @@ -/* Copyright 2021-2022 NVIDIA Corporation +/* Copyright 2021-2023 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/cunumeric/unary/scalar_unary_red.h b/src/cunumeric/unary/scalar_unary_red.h index ef96d1bfc..570c0d605 100644 --- a/src/cunumeric/unary/scalar_unary_red.h +++ b/src/cunumeric/unary/scalar_unary_red.h @@ -1,4 +1,4 @@ -/* Copyright 2021-2022 NVIDIA Corporation +/* Copyright 2021-2023 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -24,6 +24,7 @@ namespace cunumeric { struct ScalarUnaryRedArgs { const Array& out; const Array& in; + const Array& where; UnaryRedCode op_code; legate::DomainPoint shape; std::vector args; diff --git a/src/cunumeric/unary/scalar_unary_red_omp.cc b/src/cunumeric/unary/scalar_unary_red_omp.cc index c33576355..646f0193a 100644 --- a/src/cunumeric/unary/scalar_unary_red_omp.cc +++ b/src/cunumeric/unary/scalar_unary_red_omp.cc @@ -1,4 +1,4 @@ -/* Copyright 2021-2022 NVIDIA Corporation +/* Copyright 2021-2023 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. diff --git a/src/cunumeric/unary/scalar_unary_red_template.inl b/src/cunumeric/unary/scalar_unary_red_template.inl index 3ca19e8a7..35173abeb 100644 --- a/src/cunumeric/unary/scalar_unary_red_template.inl +++ b/src/cunumeric/unary/scalar_unary_red_template.inl @@ -1,4 +1,4 @@ -/* Copyright 2021-2022 NVIDIA Corporation +/* Copyright 2021-2023 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -28,7 +28,7 @@ namespace cunumeric { using namespace legate; -template +template struct ScalarUnaryRed { using OP = UnaryRedOp; using LG_OP = typename OP::OP; @@ -36,6 +36,7 @@ struct ScalarUnaryRed { using RHS = legate_type_of; using OUT = AccessorRD; using IN = AccessorRO; + using WHERE = AccessorRO; IN in; const RHS* inptr; @@ -48,6 +49,8 @@ struct ScalarUnaryRed { RHS to_find; RHS mu; bool dense; + WHERE where; + const bool* whereptr; struct DenseReduction {}; struct SparseReduction {}; @@ -64,42 +67,53 @@ struct ScalarUnaryRed { if constexpr (OP_CODE == UnaryRedCode::CONTAINS) { to_find = args.args[0].scalar(); } if constexpr (OP_CODE == UnaryRedCode::VARIANCE) { mu = args.args[0].scalar(); } + if constexpr (HAS_WHERE) where = args.where.read_accessor(rect); #ifndef LEGATE_BOUNDS_CHECKS // Check to see if this is dense or not if (in.accessor.is_dense_row_major(rect)) { dense = true; inptr = in.ptr(rect); } + if constexpr (HAS_WHERE) { + dense = dense && where.accessor.is_dense_row_major(rect); + if (dense) whereptr = where.ptr(rect); + } #endif } __CUDA_HD__ void operator()(LHS& lhs, size_t idx, LHS identity, DenseReduction) const noexcept { + bool mask = true; + if constexpr (HAS_WHERE) mask = whereptr[idx]; + if constexpr (OP_CODE == UnaryRedCode::CONTAINS) { - if (inptr[idx] == to_find) { lhs = true; } + if (mask && (inptr[idx] == to_find)) { lhs = true; } } else if constexpr (OP_CODE == UnaryRedCode::ARGMAX || OP_CODE == UnaryRedCode::ARGMIN || OP_CODE == UnaryRedCode::NANARGMAX || OP_CODE == UnaryRedCode::NANARGMIN) { auto p = pitches.unflatten(idx, origin); - OP::template fold(lhs, OP::convert(p, shape, identity, inptr[idx])); + if (mask) OP::template fold(lhs, OP::convert(p, shape, identity, inptr[idx])); } else if constexpr (OP_CODE == UnaryRedCode::VARIANCE) { - OP::template fold(lhs, OP::convert(inptr[idx] - mu, identity)); + if (mask) OP::template fold(lhs, OP::convert(inptr[idx] - mu, identity)); } else { - OP::template fold(lhs, OP::convert(inptr[idx], identity)); + if (mask) OP::template fold(lhs, OP::convert(inptr[idx], identity)); } } __CUDA_HD__ void operator()(LHS& lhs, size_t idx, LHS identity, SparseReduction) const noexcept { - auto p = pitches.unflatten(idx, origin); + auto p = pitches.unflatten(idx, origin); + bool mask = true; + if constexpr (HAS_WHERE) mask = where[p]; + if constexpr (OP_CODE == UnaryRedCode::CONTAINS) { - if (in[p] == to_find) { lhs = true; } + if (mask && (in[p] == to_find)) { lhs = true; } } else if constexpr (OP_CODE == UnaryRedCode::ARGMAX || OP_CODE == UnaryRedCode::ARGMIN || OP_CODE == UnaryRedCode::NANARGMAX || OP_CODE == UnaryRedCode::NANARGMIN) { - OP::template fold(lhs, OP::convert(p, shape, identity, in[p])); + if (mask) OP::template fold(lhs, OP::convert(p, shape, identity, in[p])); } else if constexpr (OP_CODE == UnaryRedCode::VARIANCE) { - OP::template fold(lhs, OP::convert(in[p] - mu, identity)); + if (mask) OP::template fold(lhs, OP::convert(in[p] - mu, identity)); } else { - OP::template fold(lhs, OP::convert(in[p], identity)); + if (mask) OP::template fold(lhs, OP::convert(in[p], identity)); } } @@ -120,14 +134,14 @@ struct ScalarUnaryRed { } }; -template +template struct ScalarUnaryRedImpl { template void operator()(ScalarUnaryRedArgs& args) const { // The operation is always valid for contains if constexpr (UnaryRedOp::valid || OP_CODE == UnaryRedCode::CONTAINS) { - ScalarUnaryRed red(args); + ScalarUnaryRed red(args); red.execute(); } } @@ -136,10 +150,13 @@ struct ScalarUnaryRedImpl { template struct ScalarUnaryRedDispatch { template - void operator()(ScalarUnaryRedArgs& args) const + void operator()(ScalarUnaryRedArgs& args, bool has_where) const { auto dim = std::max(1, args.in.dim()); - double_dispatch(dim, args.in.code(), ScalarUnaryRedImpl{}, args); + if (has_where) + double_dispatch(dim, args.in.code(), ScalarUnaryRedImpl{}, args); + else + double_dispatch(dim, args.in.code(), ScalarUnaryRedImpl{}, args); } }; @@ -149,19 +166,28 @@ static void scalar_unary_red_template(TaskContext& context) auto& inputs = context.inputs(); auto& scalars = context.scalars(); + auto op_code = scalars[0].value(); + auto shape = scalars[1].value(); + bool has_where = scalars[2].value(); + size_t start_idx = has_where ? 2 : 1; std::vector extra_args; - for (size_t idx = 1; idx < inputs.size(); ++idx) extra_args.push_back(std::move(inputs[idx])); - - auto op_code = scalars[0].value(); - auto shape = scalars[1].value(); + extra_args.reserve(inputs.size() - start_idx); + for (size_t idx = start_idx; idx < inputs.size(); ++idx) + extra_args.emplace_back(std::move(inputs[idx])); // If the RHS was a scalar, use (1,) as the shape if (shape.dim == 0) { shape.dim = 1; shape[0] = 1; } - ScalarUnaryRedArgs args{ - context.reductions()[0], inputs[0], op_code, shape, std::move(extra_args)}; - op_dispatch(args.op_code, ScalarUnaryRedDispatch{}, args); + + Array dummy_where; + ScalarUnaryRedArgs args{context.reductions()[0], + inputs[0], + has_where ? inputs[1] : dummy_where, + op_code, + shape, + std::move(extra_args)}; + op_dispatch(args.op_code, ScalarUnaryRedDispatch{}, args, has_where); } } // namespace cunumeric diff --git a/src/cunumeric/unary/unary_red.cc b/src/cunumeric/unary/unary_red.cc index bec85fc6f..b37d1a4b2 100644 --- a/src/cunumeric/unary/unary_red.cc +++ b/src/cunumeric/unary/unary_red.cc @@ -1,4 +1,4 @@ -/* Copyright 2021-2022 NVIDIA Corporation +/* Copyright 2021-2023 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -21,23 +21,28 @@ namespace cunumeric { using namespace legate; -template -struct UnaryRedImplBody { +template +struct UnaryRedImplBody { using OP = UnaryRedOp; using LG_OP = typename OP::OP; using RHS = legate_type_of; void operator()(AccessorRD lhs, AccessorRO rhs, + AccessorRO where, const Rect& rect, const Pitches& pitches, int collapsed_dim, size_t volume) const { for (size_t idx = 0; idx < volume; ++idx) { - auto point = pitches.unflatten(idx, rect.lo); - auto identity = LG_OP::identity; - lhs.reduce(point, OP::convert(point, collapsed_dim, identity, rhs[point])); + auto point = pitches.unflatten(idx, rect.lo); + bool mask = true; + if constexpr (HAS_WHERE) mask = where[point]; + if (mask) { + auto identity = LG_OP::identity; + lhs.reduce(point, OP::convert(point, collapsed_dim, identity, rhs[point])); + } } } }; diff --git a/src/cunumeric/unary/unary_red.cu b/src/cunumeric/unary/unary_red.cu index b417b3638..b5e0e5eb1 100644 --- a/src/cunumeric/unary/unary_red.cu +++ b/src/cunumeric/unary/unary_red.cu @@ -1,4 +1,4 @@ -/* Copyright 2021-2022 NVIDIA Corporation +/* Copyright 2021-2023 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -201,26 +201,14 @@ std::ostream& operator<<(std::ostream& os, const ThreadBlocks& blocks) return os; } -template -static __device__ __forceinline__ Point local_reduce(LHS& result, - AccessorRO in, - LHS identity, - const ThreadBlocks& blocks, - const Rect& domain, - int32_t collapsed_dim) +template +static void __device__ __forceinline__ collapse_dims(LHS& result, + Point& point, + const Rect& domain, + int32_t collapsed_dim, + LHS identity, + coord_t tid) { - const coord_t tid = threadIdx.x; - const coord_t bid = blockIdx.x; - - Point point = blocks.point(bid, tid, domain.lo); - if (!domain.contains(point)) return point; - - while (point[collapsed_dim] <= domain.hi[collapsed_dim]) { - LHS value = OP::convert(point, collapsed_dim, identity, in[point]); - REDOP::template fold(result, value); - blocks.next_point(point); - } - #if __CUDA_ARCH__ >= 700 // If we're collapsing the innermost dimension, we perform some optimization // with shared memory to reduce memory traffic due to atomic updates @@ -276,26 +264,55 @@ static __device__ __forceinline__ Point local_reduce(LHS& result, // put points back in the bounds to appease the checks. point[collapsed_dim] = domain.lo[collapsed_dim]; #endif +} + +template +static __device__ __forceinline__ Point local_reduce(LHS& result, + AccessorRO in, + AccessorRO where, + LHS identity, + const ThreadBlocks& blocks, + const Rect& domain, + int32_t collapsed_dim) +{ + const coord_t tid = threadIdx.x; + const coord_t bid = blockIdx.x; + + Point point = blocks.point(bid, tid, domain.lo); + if (!domain.contains(point)) return point; + + bool mask = true; + while (point[collapsed_dim] <= domain.hi[collapsed_dim]) { + if constexpr (HAS_WHERE) mask = where[point]; + if (mask) { + LHS value = OP::convert(point, collapsed_dim, identity, in[point]); + REDOP::template fold(result, value); + } + blocks.next_point(point); + } + + collapse_dims(result, point, domain, collapsed_dim, identity, tid); return point; } -template +template static __global__ void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) reduce_with_rd_acc(AccessorRD out, AccessorRO in, + AccessorRO where, LHS identity, ThreadBlocks blocks, Rect domain, int32_t collapsed_dim) { auto result = identity; - auto point = - local_reduce(result, in, identity, blocks, domain, collapsed_dim); + auto point = local_reduce( + result, in, where, identity, blocks, domain, collapsed_dim); if (result != identity) out.reduce(point, result); } -template -struct UnaryRedImplBody { +template +struct UnaryRedImplBody { using OP = UnaryRedOp; using LG_OP = typename OP::OP; using RHS = legate_type_of; @@ -303,12 +320,13 @@ struct UnaryRedImplBody { void operator()(AccessorRD lhs, AccessorRO rhs, + AccessorRO where, const Rect& rect, const Pitches& pitches, int collapsed_dim, size_t volume) const { - auto Kernel = reduce_with_rd_acc; + auto Kernel = reduce_with_rd_acc; auto stream = get_cached_stream(); ThreadBlocks blocks; @@ -316,7 +334,7 @@ struct UnaryRedImplBody { blocks.compute_maximum_concurrency(reinterpret_cast(Kernel)); Kernel<<>>( - lhs, rhs, LG_OP::identity, blocks, rect, collapsed_dim); + lhs, rhs, where, LG_OP::identity, blocks, rect, collapsed_dim); CHECK_CUDA_STREAM(stream); } }; diff --git a/src/cunumeric/unary/unary_red.h b/src/cunumeric/unary/unary_red.h index bea848468..a7b44584f 100644 --- a/src/cunumeric/unary/unary_red.h +++ b/src/cunumeric/unary/unary_red.h @@ -24,6 +24,7 @@ namespace cunumeric { struct UnaryRedArgs { const Array& lhs; const Array& rhs; + const Array& where; int32_t collapsed_dim; UnaryRedCode op_code; }; diff --git a/src/cunumeric/unary/unary_red_omp.cc b/src/cunumeric/unary/unary_red_omp.cc index 21ba49d4a..fd9ccce60 100644 --- a/src/cunumeric/unary/unary_red_omp.cc +++ b/src/cunumeric/unary/unary_red_omp.cc @@ -1,4 +1,4 @@ -/* Copyright 2021-2022 NVIDIA Corporation +/* Copyright 2021-2023 NVIDIA Corporation * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -72,14 +72,15 @@ class Splitter { size_t pitches_[DIM]; }; -template -struct UnaryRedImplBody { +template +struct UnaryRedImplBody { using OP = UnaryRedOp; using LG_OP = typename OP::OP; using RHS = legate_type_of; void operator()(AccessorRD lhs, AccessorRO rhs, + AccessorRO where, const Rect& rect, const Pitches& pitches, int collapsed_dim, @@ -89,12 +90,17 @@ struct UnaryRedImplBody { auto split = splitter.split(rect, collapsed_dim); #pragma omp parallel for schedule(static) - for (size_t o_idx = 0; o_idx < split.outer; ++o_idx) + for (size_t o_idx = 0; o_idx < split.outer; ++o_idx) { for (size_t i_idx = 0; i_idx < split.inner; ++i_idx) { - auto point = splitter.combine(o_idx, i_idx, rect.lo); - auto identity = LG_OP::identity; - lhs.reduce(point, OP::convert(point, collapsed_dim, identity, rhs[point])); + auto point = splitter.combine(o_idx, i_idx, rect.lo); + bool mask = true; + if constexpr (HAS_WHERE) mask = where[point]; + if (mask) { + auto identity = LG_OP::identity; + lhs.reduce(point, OP::convert(point, collapsed_dim, identity, rhs[point])); + } } + } } }; diff --git a/src/cunumeric/unary/unary_red_template.inl b/src/cunumeric/unary/unary_red_template.inl index 1e3b298d3..aa038384f 100644 --- a/src/cunumeric/unary/unary_red_template.inl +++ b/src/cunumeric/unary/unary_red_template.inl @@ -27,10 +27,10 @@ namespace cunumeric { using namespace legate; -template +template struct UnaryRedImplBody; -template +template struct UnaryRedImpl { template (rect); auto lhs = args.lhs.reduce_accessor(rect); - UnaryRedImplBody()( - lhs, rhs, rect, pitches, args.collapsed_dim, volume); + + AccessorRO where; + if constexpr (HAS_WHERE) { where = args.where.read_accessor(rect); } + UnaryRedImplBody()( + lhs, rhs, where, rect, pitches, args.collapsed_dim, volume); } template +template struct UnaryRedDispatch { template void operator()(UnaryRedArgs& args) const { auto dim = std::max(1, args.rhs.dim()); - return double_dispatch(dim, args.rhs.code(), UnaryRedImpl{}, args); + return double_dispatch(dim, args.rhs.code(), UnaryRedImpl{}, args); } }; @@ -78,10 +81,18 @@ static void unary_red_template(TaskContext& context) auto& inputs = context.inputs(); auto& reductions = context.reductions(); auto& scalars = context.scalars(); - - UnaryRedArgs args{ - reductions[0], inputs[0], scalars[0].value(), scalars[1].value()}; - op_dispatch(args.op_code, UnaryRedDispatch{}, args); + bool has_where = scalars[2].value(); + Array dummy_where; + UnaryRedArgs args{reductions[0], + inputs[0], + has_where ? inputs[1] : dummy_where, + scalars[0].value(), + scalars[1].value()}; + if (has_where) { + op_dispatch(args.op_code, UnaryRedDispatch{}, args); + } else { + op_dispatch(args.op_code, UnaryRedDispatch{}, args); + } } } // namespace cunumeric diff --git a/tests/integration/test_logical.py b/tests/integration/test_logical.py index ca9b99220..b0f83aaa6 100644 --- a/tests/integration/test_logical.py +++ b/tests/integration/test_logical.py @@ -102,19 +102,24 @@ def test_nd_inputs(ndim, func): assert np.array_equal(out_np, out_num) -@pytest.mark.skip def test_where(): - # "the `where` parameter is currently not supported" - x = np.array([[True, True, False], [True, True, True]]) y = np.array([[True, False], [True, True]]) cy = num.array(y) + # where needs to be broadcasted assert num.array_equal( - num.all(cy, where=[True, False]), np.all(x, where=[True, False]) + num.all(cy, where=[True, False]), np.all(y, where=[True, False]) ) assert num.array_equal( num.any(cy, where=[[True], [False]]), - np.any(x, where=[[True], [False]]), + np.any(y, where=[[True], [False]]), + ) + + # Where is a boolean + assert num.array_equal(num.all(cy, where=True), np.all(y, where=True)) + assert num.array_equal( + num.any(cy, where=False), + np.any(y, where=False), ) diff --git a/tests/integration/test_mean.py b/tests/integration/test_mean.py index 0f9064280..40092455c 100755 --- a/tests/integration/test_mean.py +++ b/tests/integration/test_mean.py @@ -1,4 +1,4 @@ -# Copyright 2021-2022 NVIDIA Corporation +# Copyright 2021-2023 NVIDIA Corporation # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. @@ -60,13 +60,43 @@ def test_scalar(val): assert np.array_equal(res_np, res_num) +@pytest.mark.parametrize("val", (0.0, 10.0, -5, 1 + 1j)) +def test_scalar_where(val): + res_np = np.mean(val, where=True) + res_num = num.mean(val, where=True) + assert np.array_equal(res_np, res_num) + + @pytest.mark.parametrize("size", NO_EMPTY_SIZE) def test_basic(size): arr_np = np.random.randint(-5, 5, size=size) arr_num = num.array(arr_np) res_np = np.mean(arr_np) res_num = num.mean(arr_num) - np.array_equal(res_np, res_num) + assert np.array_equal(res_np, res_num) + + +@pytest.mark.parametrize("size", NO_EMPTY_SIZE) +def test_basic_where(size): + arr_np = np.random.randint(-5, 5, size=size) + arr_num = num.array(arr_np) + where_np = arr_np % 2 + where_np = arr_np.astype(bool) + where_num = num.array(where_np) + res_np = np.mean(arr_np, where=where_np) + res_num = num.mean(arr_num, where=where_num) + assert np.array_equal(res_np, res_num, equal_nan=True) + + +@pytest.mark.parametrize("size", NO_EMPTY_SIZE) +def test_where_broadcast(size): + arr_np = np.random.randint(-5, 5, size=size) + arr_num = num.array(arr_np) + where_np = np.zeros((1,), bool) + where_num = num.array(where_np) + res_np = np.mean(arr_np, where=where_np) + res_num = num.mean(arr_num, where=where_num) + assert np.array_equal(res_np, res_num, equal_nan=True) @pytest.mark.xfail @@ -94,6 +124,20 @@ def test_axis_keepdims(size, keepdims): assert np.array_equal(out_np, out_num) +@pytest.mark.parametrize("size", NO_EMPTY_SIZE) +def test_axis_where(size): + arr_np = np.random.randint(-5, 5, size=size) + arr_num = num.array(arr_np) + where_np = arr_np % 2 + where_np = arr_np.astype(bool) + where_num = num.array(where_np) + ndim = arr_np.ndim + for axis in range(-ndim, ndim): + out_np = np.mean(arr_np, axis=axis, where=where_np) + out_num = num.mean(arr_num, axis=axis, where=where_num) + assert np.array_equal(out_np, out_num, equal_nan=True) + + @pytest.mark.parametrize("array_dt", (np.int32, np.float32, np.complex64)) @pytest.mark.parametrize("dt", (np.int32, np.float32, np.complex64)) @pytest.mark.parametrize("size", NO_EMPTY_SIZE) diff --git a/tests/integration/test_nan_reduction.py b/tests/integration/test_nan_reduction.py index 34bbd1447..e57a46d0b 100644 --- a/tests/integration/test_nan_reduction.py +++ b/tests/integration/test_nan_reduction.py @@ -291,6 +291,49 @@ def test_all_nans_nansum(self, ndim): assert out_num == 0.0 + def test_where(self): + arr = [[1, np.nan, 3], [2, np.nan, 4]] + out_np = np.nansum(arr, where=[False, True, True]) + out_num = num.nansum(arr, where=[False, True, True]) + assert np.allclose(out_np, out_num) + + out_np = np.nanprod(arr, where=[False, True, True]) + out_num = num.nanprod(arr, where=[False, True, True]) + assert np.allclose(out_np, out_num) + + out_np = np.nanmax( + arr, where=[[False, True, True], [False, False, True]], initial=-1 + ) + out_num = num.nanmax( + arr, where=[[False, True, True], [False, False, True]], initial=-1 + ) + assert np.allclose(out_np, out_num) + + out_np = np.nanmin( + arr, where=[[False, True, True], [False, True, True]], initial=10 + ) + out_num = num.nanmin( + arr, where=[[False, True, True], [False, True, True]], initial=10 + ) + assert np.allclose(out_np, out_num) + + # where is a boolean + out_np = np.nansum(arr, where=True) + out_num = num.nansum(arr, where=True) + assert np.allclose(out_np, out_num) + + out_np = np.nanprod(arr, where=False) + out_num = num.nanprod(arr, where=False) + assert np.allclose(out_np, out_num) + + out_np = np.nanmax(arr, where=True, initial=-1) + out_num = num.nanmax(arr, where=True, initial=-1) + assert np.allclose(out_np, out_num) + + out_np = np.nanmin(arr, where=True, initial=10) + out_num = num.nanmin(arr, where=True, initial=10) + assert np.allclose(out_np, out_num) + class TestCornerCases: """ diff --git a/tests/integration/test_nanmean.py b/tests/integration/test_nanmean.py new file mode 100755 index 000000000..98962842b --- /dev/null +++ b/tests/integration/test_nanmean.py @@ -0,0 +1,154 @@ +# Copyright 2021-2023 NVIDIA Corporation +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import numpy as np +import pytest + +import cunumeric as num + +DIM = 7 + +NO_EMPTY_SIZE = ( + (1,), + (DIM,), + (1, 1), + (1, DIM), + (DIM, 1), + (DIM, DIM), + (1, 1, 1), + (DIM, 1, 1), + (1, DIM, 1), + (1, 1, DIM), + (DIM, DIM, DIM), +) + + +def gen_out_shape(size, axis): + if axis is None: + return () + if axis < 0: + axis += len(size) + if axis >= 0 and axis < len(size): + return size[:axis] + size[axis + 1 :] + else: + return -1 + + +@pytest.mark.parametrize("arr", ([], [[], []])) +def test_empty_arr(arr): + res_np = np.nanmean(arr) + res_num = num.nanmean(arr) + assert np.isnan(res_np) and np.isnan(res_num) + + +@pytest.mark.parametrize("val", (np.nan, 0.0, 10.0, -5, 1 + 1j)) +def test_scalar(val): + res_np = np.nanmean(val) + res_num = num.nanmean(val) + assert np.array_equal(res_np, res_num, equal_nan=True) + + +@pytest.mark.parametrize("val", (np.nan, 0.0, 10.0, -5, 1 + 1j)) +def test_scalar_where(val): + res_np = np.nanmean(val, where=True) + res_num = num.nanmean(val, where=True) + assert np.array_equal(res_np, res_num, equal_nan=True) + + +@pytest.mark.parametrize("size", NO_EMPTY_SIZE) +def test_basic(size): + arr_np = np.random.randint(-5, 5, size=size).astype(float) + arr_np[arr_np % 2 == 0] = np.nan + arr_num = num.array(arr_np) + res_np = np.nanmean(arr_np) + res_num = num.nanmean(arr_num) + assert np.array_equal(res_np, res_num, equal_nan=True) + + +@pytest.mark.parametrize("size", NO_EMPTY_SIZE) +def test_basic_where(size): + arr_np = np.random.randint(-5, 5, size=size).astype(float) + arr_np[arr_np % 2 == 0] = np.nan + arr_num = num.array(arr_np) + where_np = arr_np % 2 + where_np = arr_np.astype(bool) + where_num = num.array(where_np) + res_np = np.nanmean(arr_np, where=where_np) + res_num = num.nanmean(arr_num, where=where_num) + assert np.array_equal(res_np, res_num, equal_nan=True) + + +@pytest.mark.xfail +@pytest.mark.parametrize("axis", ((-3, -1), (-1, 0), (-2, 2), (0, 2))) +def test_axis_tuple(axis): + # In Numpy, it pass + # In cuNumeric, it raises NotImplementedError + size = (3, 4, 7) + arr_np = np.random.randint(-5, 5, size=size).astype(float) + arr_np[arr_np % 2 == 1] = np.nan + arr_num = num.array(arr_np) + out_np = np.nanmean(arr_np, axis=axis) + out_num = num.nanmean(arr_num, axis=axis) + assert np.array_equal(out_np, out_num, equal_nan=True) + + +@pytest.mark.parametrize("keepdims", (False, True)) +@pytest.mark.parametrize("size", NO_EMPTY_SIZE) +def test_axis_keepdims(size, keepdims): + arr_np = np.random.randint(-5, 5, size=size).astype(float) + arr_np[arr_np % 2 == 1] = np.nan + arr_num = num.array(arr_np) + ndim = arr_np.ndim + for axis in range(-ndim, ndim): + out_np = np.nanmean(arr_np, axis=axis, keepdims=keepdims) + out_num = num.nanmean(arr_num, axis=axis, keepdims=keepdims) + assert np.array_equal(out_np, out_num, equal_nan=True) + + +@pytest.mark.parametrize("size", NO_EMPTY_SIZE) +def test_axis_where(size): + arr_np = np.random.randint(-5, 5, size=size).astype(float) + arr_np[arr_np % 2 == 0] = np.nan + arr_num = num.array(arr_np) + where_np = arr_np[arr_np % 2 == 1] % 2 + where_np = arr_np.astype(bool) + where_num = num.array(where_np) + ndim = arr_np.ndim + for axis in range(-ndim, ndim): + out_np = np.nanmean(arr_np, axis=axis, where=where_np) + out_num = num.nanmean(arr_num, axis=axis, where=where_num) + assert np.array_equal(out_np, out_num, equal_nan=True) + + +@pytest.mark.parametrize("out_dt", (np.float32, np.complex128)) +@pytest.mark.parametrize("size", NO_EMPTY_SIZE) +def test_out(size, out_dt): + arr_np = np.random.randint(-5, 5, size=size).astype(float) + arr_np[arr_np % 2 == 0] = np.nan + arr_num = num.array(arr_np) + ndim = arr_np.ndim + for axis in (-1, ndim - 1, None): + out_shape = gen_out_shape(size, axis) + out_np = np.empty(out_shape, dtype=out_dt) + out_num = num.empty(out_shape, dtype=out_dt) + np.nanmean(arr_np, axis=axis, out=out_np) + num.nanmean(arr_num, axis=axis, out=out_num) + np.array_equal(out_np, out_num, equal_nan=True) + + +if __name__ == "__main__": + import sys + + sys.exit(pytest.main(sys.argv)) diff --git a/tests/integration/test_prod.py b/tests/integration/test_prod.py index ab0f4def8..c004c95a3 100644 --- a/tests/integration/test_prod.py +++ b/tests/integration/test_prod.py @@ -147,15 +147,21 @@ def test_initial_empty_array(self): out_np = np.prod(arr_np, initial=initial_value) assert allclose(out_np, out_num) - @pytest.mark.xfail def test_where(self): arr = [[1, 2], [3, 4]] - out_np = np.prod(arr, where=[False, True]) # return 8 - # cuNumeric raises NotImplementedError: - # the `where` parameter is currently not supported + out_np = np.prod(arr, where=[False, True]) out_num = num.prod(arr, where=[False, True]) assert allclose(out_np, out_num) + # where is boolean + out_np = np.prod(arr, where=True) + out_num = num.prod(arr, where=True) + assert allclose(out_np, out_num) + + out_np = np.prod(arr, where=False) + out_num = num.prod(arr, where=False) + assert allclose(out_np, out_num) + class TestProdPositive(object): """ diff --git a/tests/integration/test_reduction.py b/tests/integration/test_reduction.py index a7a89a6af..f3379265b 100644 --- a/tests/integration/test_reduction.py +++ b/tests/integration/test_reduction.py @@ -151,15 +151,21 @@ def test_initial_empty_array(self): out_np = np.sum(arr_np, initial=initial_value) # return initial_value assert allclose(out_np, out_num) - @pytest.mark.xfail def test_where(self): arr = [[1, 2], [3, 4]] out_np = np.sum(arr, where=[False, True]) # return 6 - # cuNumeric raises NotImplementedError: - # "the `where` parameter is currently not supported" out_num = num.sum(arr, where=[False, True]) assert allclose(out_np, out_num) + # where is a boolean + out_np = np.sum(arr, where=True) + out_num = num.sum(arr, where=True) + assert allclose(out_np, out_num) + + out_np = np.sum(arr, where=False) + out_num = num.sum(arr, where=False) + assert allclose(out_np, out_num) + class TestSumPositive(object): """ diff --git a/tests/integration/test_stats.py b/tests/integration/test_stats.py index dfa1b0fa3..256a6b77f 100644 --- a/tests/integration/test_stats.py +++ b/tests/integration/test_stats.py @@ -39,7 +39,7 @@ def check_result(in_np, out_np, out_num, **isclose_kwargs): is_negative_test = False result = ( - allclose(out_np, out_num, **isclose_kwargs) + allclose(out_np, out_num, equal_nan=True, **isclose_kwargs) and out_np.dtype == out_num.dtype ) if not result and not is_negative_test: @@ -131,6 +131,24 @@ def test_var_default_shape(dtype, ddof, axis, keepdims): check_op(op_np, op_num, np_in, dtype) +@pytest.mark.parametrize("dtype", dtypes) +@pytest.mark.parametrize("ddof", [0, 1]) +@pytest.mark.parametrize("axis", [None, 0, 1]) +@pytest.mark.parametrize("keepdims", [False, True]) +def test_var_where(dtype, ddof, axis, keepdims): + np_in = get_op_input(astype=dtype) + where = (np_in.astype(int) % 2).astype(bool) + + op_np = functools.partial( + np.var, ddof=ddof, axis=axis, keepdims=keepdims, where=where + ) + op_num = functools.partial( + num.var, ddof=ddof, axis=axis, keepdims=keepdims, where=where + ) + + check_op(op_np, op_num, np_in, dtype) + + @pytest.mark.parametrize("dtype", dtypes) @pytest.mark.parametrize("ddof", [0, 1]) @pytest.mark.parametrize("axis", [None, 0, 1, 2])