Skip to content

Commit

Permalink
fixing var for the case when where is not None
Browse files Browse the repository at this point in the history
  • Loading branch information
ipdemes committed Nov 30, 2023
1 parent b39364b commit bdb49e9
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 8 deletions.
31 changes: 24 additions & 7 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3119,16 +3119,25 @@ def _normalize_summation(
divisor = reduce(lambda x, y: x * y, self.shape, 1) - ddof
else:
if where is not None:
divisor = (
where.sum(axis=axis, dtype=dtype, keepdims=keepdims) - ddof
divisor = where.sum(
axis=axis, dtype=sum_array.dtype, keepdims=keepdims
)
if ddof != 0 and not np.isscalar(divisor):
mask = divisor != 0
values = divisor - ddof
divisor._thunk.putmask(mask._thunk, values._thunk)
else:
divisor -= ddof
else:
divisor = self.shape[axis] - ddof

# Divide by the number of things in the collapsed dimensions
# Pick the right kinds of division based on the dtype
if np.ndim(divisor) == 0:
if isinstance(divisor, ndarray):
divisor = divisor.astype(sum_array.dtype)
else:
divisor = np.array(divisor, dtype=sum_array.dtype) # type: ignore [assignment] # noqa

if dtype.kind == "f" or dtype.kind == "c":
sum_array.__itruediv__(divisor)
else:
Expand Down Expand Up @@ -3267,8 +3276,9 @@ def var(
# mean can be broadcast against the original array
mu = self.mean(axis=axis, dtype=dtype, keepdims=True, where=where)

where_array = broadcast_where(where, self.shape)

# 1D arrays (or equivalent) should benefit from this unary reduction:
#
if axis is None or calculate_volume(tuple_pop(self.shape, axis)) == 1:
# this is a scalar reduction and we can optimize this as a single
# pass through a scalar reduction
Expand All @@ -3279,7 +3289,7 @@ def var(
dtype=dtype,
out=out,
keepdims=keepdims,
where=where,
where=where_array,
args=(mu,),
)
else:
Expand All @@ -3300,10 +3310,17 @@ def var(
dtype=dtype,
out=out,
keepdims=keepdims,
where=where,
where=where_array,
)

self._normalize_summation(result, axis=axis, dtype=dtype, ddof=ddof)
self._normalize_summation(
result,
axis=axis,
dtype=result.dtype,
ddof=ddof,
keepdims=keepdims,
where=where_array,
)

return result

Expand Down
20 changes: 19 additions & 1 deletion tests/integration/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def check_result(in_np, out_np, out_num, **isclose_kwargs):
is_negative_test = False

result = (
allclose(out_np, out_num, **isclose_kwargs)
allclose(out_np, out_num, equal_nan=True, **isclose_kwargs)
and out_np.dtype == out_num.dtype
)
if not result and not is_negative_test:
Expand Down Expand Up @@ -131,6 +131,24 @@ def test_var_default_shape(dtype, ddof, axis, keepdims):
check_op(op_np, op_num, np_in, dtype)


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("ddof", [0, 1])
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize("keepdims", [False, True])
def test_var_where(dtype, ddof, axis, keepdims):
np_in = get_op_input(astype=dtype)
where = (np_in.astype(int) % 2).astype(bool)

op_np = functools.partial(
np.var, ddof=ddof, axis=axis, keepdims=keepdims, where=where
)
op_num = functools.partial(
num.var, ddof=ddof, axis=axis, keepdims=keepdims, where=where
)

check_op(op_np, op_num, np_in, dtype)


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("ddof", [0, 1])
@pytest.mark.parametrize("axis", [None, 0, 1, 2])
Expand Down

0 comments on commit bdb49e9

Please sign in to comment.