Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Supporting "where" for unary operations #1061

Merged
merged 68 commits into from
Dec 3, 2023
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
Show all changes
68 commits
Select commit Hold shift + click to select a range
f0eef07
adding support for adding where to unary reduction operations on the …
ipdemes Aug 29, 2023
a93748d
adding support for adding where to unary reduction operations on the …
ipdemes Aug 30, 2023
79c7aca
formatting
ipdemes Aug 30, 2023
a64e829
fixing error when calculating divisor
ipdemes Aug 30, 2023
6f7b059
fixing error with mean(where)
ipdemes Aug 31, 2023
d12fe55
updating test
ipdemes Sep 1, 2023
ec41b52
adding logic for where in unary_red on C++
ipdemes Sep 1, 2023
9128530
towards improving data usage by mean
ipdemes Sep 2, 2023
76dc871
debugging
ipdemes Sep 3, 2023
cbe93d9
removing debug output
ipdemes Sep 3, 2023
419a5af
adding support for broadcasting where + initial implementation of nan…
ipdemes Oct 11, 2023
0d7d66c
adding test for nanmean
ipdemes Oct 17, 2023
baf61ed
some clean-up + formatting
ipdemes Oct 17, 2023
30b1e18
fixing issues after rebase
ipdemes Oct 17, 2023
2277234
code clean-up
ipdemes Oct 17, 2023
73e0609
code clean-up
ipdemes Oct 18, 2023
dfafdcc
adding test_where for unary operations
ipdemes Oct 18, 2023
3071edd
adding nanmean to docs
ipdemes Oct 18, 2023
67aab76
nanmean is not an official method of ndarray
ipdemes Oct 18, 2023
09a4477
fixing test
ipdemes Oct 18, 2023
851dd42
Update src/cunumeric/unary/scalar_unary_red_template.inl
ipdemes Nov 9, 2023
83308a6
Update src/cunumeric/unary/scalar_unary_red_template.inl
ipdemes Nov 9, 2023
df70b28
Update src/cunumeric/unary/scalar_unary_red_template.inl
ipdemes Nov 9, 2023
7070d9f
Update src/cunumeric/unary/scalar_unary_red_template.inl
ipdemes Nov 9, 2023
c00661a
Update cunumeric/array.py
ipdemes Nov 9, 2023
ab3bb10
Update src/cunumeric/unary/scalar_unary_red_template.inl
ipdemes Nov 9, 2023
df5248e
Update src/cunumeric/unary/unary_red.cu
ipdemes Nov 9, 2023
8b3d0b7
Update cunumeric/array.py
ipdemes Nov 9, 2023
2a2ead5
Update cunumeric/array.py
ipdemes Nov 9, 2023
329bf5a
Update cunumeric/array.py
ipdemes Nov 9, 2023
dcce884
Update cunumeric/eager.py
ipdemes Nov 9, 2023
a50f165
Update cunumeric/module.py
ipdemes Nov 9, 2023
c52e0c5
Update cunumeric/array.py
ipdemes Nov 9, 2023
f7235b1
Update cunumeric/module.py
ipdemes Nov 9, 2023
568dfc5
updating documentation + fixing mypy issue
ipdemes Nov 9, 2023
67abe0c
cleaning up some of the C++ code
ipdemes Nov 9, 2023
bb8a462
Update cunumeric/array.py
ipdemes Nov 9, 2023
186af6f
fixing the logic for nanmean after previous commit
ipdemes Nov 13, 2023
2729b55
cleaning up scalar_unary_red_template.inl
ipdemes Nov 13, 2023
e675374
cleaning up unary_red_* files
ipdemes Nov 13, 2023
c17f3d2
removing convert_to_predicate_ndarray
ipdemes Nov 13, 2023
e7d9096
small clean-up
ipdemes Nov 13, 2023
b6a90c9
Update cunumeric/array.py
ipdemes Nov 13, 2023
7ec317c
removing where from perform_unary_op
ipdemes Nov 13, 2023
4709638
fixing logic for nanmin/nanmax for where
ipdemes Nov 13, 2023
95bfbeb
addressing the rest of the comments from Manolis
ipdemes Nov 14, 2023
90f27bc
some C++ code clean-up
ipdemes Nov 14, 2023
fa19be7
more C++ code clean-up
ipdemes Nov 14, 2023
a51135d
Update src/cunumeric/unary/unary_red.cu
ipdemes Nov 17, 2023
e703cc9
Update src/cunumeric/unary/unary_red.cc
ipdemes Nov 17, 2023
7ea110f
Update src/cunumeric/unary/unary_red_omp.cc
ipdemes Nov 17, 2023
f080a95
Update src/cunumeric/unary/scalar_unary_red_template.inl
ipdemes Nov 17, 2023
c039cb2
Update src/cunumeric/unary/scalar_unary_red_template.inl
ipdemes Nov 17, 2023
6896b4e
Update src/cunumeric/unary/unary_red_template.inl
ipdemes Nov 17, 2023
4014239
Update src/cunumeric/unary/scalar_unary_red_template.inl
ipdemes Nov 17, 2023
fe84bb1
Update src/cunumeric/unary/unary_red_template.inl
ipdemes Nov 17, 2023
70fa320
Update src/cunumeric/unary/unary_red.cu
ipdemes Nov 17, 2023
9d5b364
Update cunumeric/array.py
ipdemes Nov 17, 2023
2bd62eb
Update cunumeric/array.py
ipdemes Nov 17, 2023
f504f05
fixing bug in cuda code
ipdemes Nov 17, 2023
b15f656
adding dtype back to sum
ipdemes Nov 17, 2023
7a77535
removing ignore_nan
ipdemes Nov 17, 2023
b39364b
Merge remote-tracking branch 'origin/branch-24.01' into unary_red_where
ipdemes Nov 21, 2023
bdb49e9
fixing var for the case when where is not None
ipdemes Nov 30, 2023
4ef0749
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 30, 2023
18a6fe8
fixing eager logic for var
ipdemes Nov 30, 2023
7c1fdaa
Merge remote-tracking branch 'origin/branch-24.01' into unary_red_where
ipdemes Nov 30, 2023
869b176
removing dtype from _normalize_summation
ipdemes Dec 2, 2023
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 24 additions & 7 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -3119,16 +3119,25 @@ def _normalize_summation(
divisor = reduce(lambda x, y: x * y, self.shape, 1) - ddof
else:
if where is not None:
divisor = (
where.sum(axis=axis, dtype=dtype, keepdims=keepdims) - ddof
divisor = where.sum(
axis=axis, dtype=sum_array.dtype, keepdims=keepdims
)
if ddof != 0 and not np.isscalar(divisor):
mask = divisor != 0
values = divisor - ddof
divisor._thunk.putmask(mask._thunk, values._thunk)
else:
divisor -= ddof
else:
divisor = self.shape[axis] - ddof

# Divide by the number of things in the collapsed dimensions
# Pick the right kinds of division based on the dtype
if np.ndim(divisor) == 0:
if isinstance(divisor, ndarray):
divisor = divisor.astype(sum_array.dtype)
else:
divisor = np.array(divisor, dtype=sum_array.dtype) # type: ignore [assignment] # noqa

if dtype.kind == "f" or dtype.kind == "c":
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you can just get rid of the dtype parameter at this point, and just use sum_array.dtype everywhere in this function.

sum_array.__itruediv__(divisor)
else:
Expand Down Expand Up @@ -3267,8 +3276,9 @@ def var(
# mean can be broadcast against the original array
mu = self.mean(axis=axis, dtype=dtype, keepdims=True, where=where)

where_array = broadcast_where(where, self.shape)

# 1D arrays (or equivalent) should benefit from this unary reduction:
#
if axis is None or calculate_volume(tuple_pop(self.shape, axis)) == 1:
# this is a scalar reduction and we can optimize this as a single
# pass through a scalar reduction
Expand All @@ -3279,7 +3289,7 @@ def var(
dtype=dtype,
out=out,
keepdims=keepdims,
where=where,
where=where_array,
args=(mu,),
)
else:
Expand All @@ -3300,10 +3310,17 @@ def var(
dtype=dtype,
out=out,
keepdims=keepdims,
where=where,
where=where_array,
)

self._normalize_summation(result, axis=axis, dtype=dtype, ddof=ddof)
self._normalize_summation(
result,
axis=axis,
dtype=result.dtype,
ddof=ddof,
keepdims=keepdims,
where=where_array,
)

return result

Expand Down
20 changes: 19 additions & 1 deletion tests/integration/test_stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ def check_result(in_np, out_np, out_num, **isclose_kwargs):
is_negative_test = False

result = (
allclose(out_np, out_num, **isclose_kwargs)
allclose(out_np, out_num, equal_nan=True, **isclose_kwargs)
and out_np.dtype == out_num.dtype
)
if not result and not is_negative_test:
Expand Down Expand Up @@ -131,6 +131,24 @@ def test_var_default_shape(dtype, ddof, axis, keepdims):
check_op(op_np, op_num, np_in, dtype)


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("ddof", [0, 1])
@pytest.mark.parametrize("axis", [None, 0, 1])
@pytest.mark.parametrize("keepdims", [False, True])
def test_var_where(dtype, ddof, axis, keepdims):
    """Compare ``num.var`` with ``np.var`` under a ``where`` mask.

    The same ``ddof``, ``axis``, ``keepdims`` and ``where`` arguments are
    bound to both implementations, which are then handed to ``check_op``
    together with the generated input and the requested ``dtype``.
    """
    np_in = get_op_input(astype=dtype)
    # Boolean mask: True wherever the integer-cast input value is odd.
    where = (np_in.astype(int) % 2).astype(bool)

    # Bind identical keyword arguments to both implementations so the
    # only difference between the two callables is the library used.
    var_kwargs = dict(ddof=ddof, axis=axis, keepdims=keepdims, where=where)
    op_np = functools.partial(np.var, **var_kwargs)
    op_num = functools.partial(num.var, **var_kwargs)

    # NOTE(review): check_op is defined elsewhere in this test module;
    # presumably it runs both callables and compares results — confirm.
    check_op(op_np, op_num, np_in, dtype)


@pytest.mark.parametrize("dtype", dtypes)
@pytest.mark.parametrize("ddof", [0, 1])
@pytest.mark.parametrize("axis", [None, 0, 1, 2])
Expand Down