Skip to content

Commit

Permalink
Supporting "where" for unary operations (#1061)
Browse files Browse the repository at this point in the history
  • Loading branch information
ipdemes authored Dec 3, 2023
1 parent 03cb875 commit 672e8c6
Show file tree
Hide file tree
Showing 24 changed files with 740 additions and 238 deletions.
20 changes: 10 additions & 10 deletions cunumeric/_ufunc/ufunc.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
# Copyright 2021-2022 NVIDIA Corporation
# Copyright 2021-2023 NVIDIA Corporation
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
Expand All @@ -14,12 +14,17 @@
#
from __future__ import annotations

from typing import TYPE_CHECKING, Any, Dict, Sequence, Union
from typing import TYPE_CHECKING, Any, Dict, Optional, Sequence, Union

import numpy as np
from legate.core.utils import OrderedSet

from ..array import check_writeable, convert_to_cunumeric_ndarray, ndarray
from ..array import (
add_boilerplate,
check_writeable,
convert_to_cunumeric_ndarray,
ndarray,
)
from ..config import BinaryOpCode, UnaryOpCode, UnaryRedCode
from ..types import NdShape

Expand Down Expand Up @@ -680,6 +685,7 @@ def __call__(

return self._maybe_cast_output(out, result)

@add_boilerplate("array")
def reduce(
self,
array: ndarray,
Expand All @@ -688,7 +694,7 @@ def reduce(
out: Union[ndarray, None] = None,
keepdims: bool = False,
initial: Union[Any, None] = None,
where: bool = True,
where: Optional[ndarray] = None,
) -> ndarray:
"""
reduce(array, axis=0, dtype=None, out=None, keepdims=False, initial=<no
Expand Down Expand Up @@ -742,16 +748,10 @@ def reduce(
--------
numpy.ufunc.reduce
"""
array = convert_to_cunumeric_ndarray(array)

if self._red_code is None:
raise NotImplementedError(
f"reduction for {self} is not yet implemented"
)
if not isinstance(where, bool) or not where:
raise NotImplementedError(
"the 'where' keyword is not yet supported"
)

# NumPy seems to be using None as the default axis value for scalars
if array.ndim == 0 and axis == 0:
Expand Down
Loading

0 comments on commit 672e8c6

Please sign in to comment.