Skip to content

Commit

Permalink
removing dtype from _normalize_summation
Browse files Browse the repository at this point in the history
  • Loading branch information
ipdemes committed Dec 2, 2023
1 parent 7c1fdaa commit 869b176
Showing 1 changed file with 5 additions and 8 deletions.
13 changes: 5 additions & 8 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3094,21 +3094,19 @@ def _normalize_summation(
self,
sum_array: Any,
axis: Any,
dtype: np.dtype[Any],
ddof: int = 0,
keepdims: bool = False,
where: Union[ndarray, None] = None,
) -> None:
dtype = sum_array.dtype
if axis is None:
if where is not None:
divisor = where._count_nonzero() - ddof
else:
divisor = reduce(lambda x, y: x * y, self.shape, 1) - ddof
else:
if where is not None:
divisor = where.sum(
axis=axis, dtype=sum_array.dtype, keepdims=keepdims
)
divisor = where.sum(axis=axis, dtype=dtype, keepdims=keepdims)
if ddof != 0 and not np.isscalar(divisor):
mask = divisor != 0
values = divisor - ddof
Expand All @@ -3121,9 +3119,9 @@ def _normalize_summation(
# Divide by the number of things in the collapsed dimensions
# Pick the right kinds of division based on the dtype
if isinstance(divisor, ndarray):
divisor = divisor.astype(sum_array.dtype)
divisor = divisor.astype(dtype)
else:
divisor = np.array(divisor, dtype=sum_array.dtype) # type: ignore [assignment] # noqa
divisor = np.array(divisor, dtype=dtype) # type: ignore [assignment] # noqa

if dtype.kind == "f" or dtype.kind == "c":
sum_array.__itruediv__(divisor)
Expand Down Expand Up @@ -3179,7 +3177,7 @@ def mean(
)

self._normalize_summation(
sum_array, axis, dtype, keepdims=keepdims, where=where_array
sum_array, axis, keepdims=keepdims, where=where_array
)

# Convert to the output we didn't already put it there
Expand Down Expand Up @@ -3303,7 +3301,6 @@ def var(
self._normalize_summation(
result,
axis=axis,
dtype=result.dtype,
ddof=ddof,
keepdims=keepdims,
where=where_array,
Expand Down

0 comments on commit 869b176

Please sign in to comment.