From 869b1762aa1dea0f9e4e9e70f48c4cdcd5d4d914 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Fri, 1 Dec 2023 18:54:50 -0800 Subject: [PATCH] removing dtype from _normalize_summation --- cunumeric/array.py | 13 +++++-------- 1 file changed, 5 insertions(+), 8 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 3da80916c..8bfc5178a 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -3094,11 +3094,11 @@ def _normalize_summation( self, sum_array: Any, axis: Any, - dtype: np.dtype[Any], ddof: int = 0, keepdims: bool = False, where: Union[ndarray, None] = None, ) -> None: + dtype = sum_array.dtype if axis is None: if where is not None: divisor = where._count_nonzero() - ddof @@ -3106,9 +3106,7 @@ def _normalize_summation( divisor = reduce(lambda x, y: x * y, self.shape, 1) - ddof else: if where is not None: - divisor = where.sum( - axis=axis, dtype=sum_array.dtype, keepdims=keepdims - ) + divisor = where.sum(axis=axis, dtype=dtype, keepdims=keepdims) if ddof != 0 and not np.isscalar(divisor): mask = divisor != 0 values = divisor - ddof @@ -3121,9 +3119,9 @@ def _normalize_summation( # Divide by the number of things in the collapsed dimensions # Pick the right kinds of division based on the dtype if isinstance(divisor, ndarray): - divisor = divisor.astype(sum_array.dtype) + divisor = divisor.astype(dtype) else: - divisor = np.array(divisor, dtype=sum_array.dtype) # type: ignore [assignment] # noqa + divisor = np.array(divisor, dtype=dtype) # type: ignore [assignment] # noqa if dtype.kind == "f" or dtype.kind == "c": sum_array.__itruediv__(divisor) @@ -3179,7 +3177,7 @@ def mean( ) self._normalize_summation( - sum_array, axis, dtype, keepdims=keepdims, where=where_array + sum_array, axis, keepdims=keepdims, where=where_array ) # Convert to the output we didn't already put it there @@ -3303,7 +3301,6 @@ def var( self._normalize_summation( result, axis=axis, - dtype=result.dtype, ddof=ddof, keepdims=keepdims, where=where_array,