Skip to content

Commit

Permalink
adding dtype back to sum
Browse files Browse the repository at this point in the history
  • Loading branch information
ipdemes committed Nov 17, 2023
1 parent f504f05 commit b15f656
Showing 1 changed file with 23 additions and 10 deletions.
33 changes: 23 additions & 10 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3127,39 +3127,52 @@ def mean(
# Do the sum
sum_array = (
self._nansum(
axis=axis, out=out, keepdims=keepdims, where=where_array
axis=axis,
out=out,
keepdims=keepdims,
dtype=dtype,
where=where_array,
)
if out is not None and out.dtype == dtype and ignore_nan
else self.sum(
axis=axis, out=out, keepdims=keepdims, where=where_array
axis=axis,
out=out,
keepdims=keepdims,
dtype=dtype,
where=where_array,
)
if out is not None and out.dtype == dtype
else self._nansum(axis=axis, keepdims=keepdims, where=where_array)
else self._nansum(
axis=axis, keepdims=keepdims, dtype=dtype, where=where_array
)
if ignore_nan
else self.sum(axis=axis, keepdims=keepdims, where=where_array)
else self.sum(
axis=axis, keepdims=keepdims, dtype=dtype, where=where_array
)
)

if axis is None:
if where_array is not None:
divisor = where_array._count_nonzero()
else:
divisor = np.array(reduce(lambda x, y: x * y, self.shape, 1))
divisor = reduce(lambda x, y: x * y, self.shape, 1)

else:
if where_array is not None:
divisor = where_array.sum(axis=axis, keepdims=keepdims)
divisor = where_array.sum(
axis=axis, dtype=dtype, keepdims=keepdims
)
else:
divisor = np.array(self.shape[axis])
divisor = self.shape[axis]

# Divide by the number of things in the collapsed dimensions
# Pick the right kinds of division based on the dtype
sum_array = sum_array.astype(dtype)
if dtype.kind == "f" or dtype.kind == "c":
sum_array.__itruediv__(
divisor.astype(dtype),
divisor,
)
else:
sum_array.__ifloordiv__(divisor.astype(dtype))
sum_array.__ifloordiv__(divisor)
# Convert to the output we didn't already put it there
if out is not None and sum_array is not out:
assert out.dtype != sum_array.dtype
Expand Down

0 comments on commit b15f656

Please sign in to comment.