Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support np.select operation #1066

Merged
merged 18 commits into from
Dec 7, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions cunumeric/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ class _CunumericSharedLib:
CUNUMERIC_SCAN_PROD: int
CUNUMERIC_SCAN_SUM: int
CUNUMERIC_SEARCHSORTED: int
CUNUMERIC_SELECT: int
CUNUMERIC_SOLVE: int
CUNUMERIC_SORT: int
CUNUMERIC_SYRK: int
Expand Down Expand Up @@ -365,6 +366,7 @@ class CuNumericOpCode(IntEnum):
SCAN_GLOBAL = _cunumeric.CUNUMERIC_SCAN_GLOBAL
SCAN_LOCAL = _cunumeric.CUNUMERIC_SCAN_LOCAL
SEARCHSORTED = _cunumeric.CUNUMERIC_SEARCHSORTED
SELECT = _cunumeric.CUNUMERIC_SELECT
SOLVE = _cunumeric.CUNUMERIC_SOLVE
SORT = _cunumeric.CUNUMERIC_SORT
SYRK = _cunumeric.CUNUMERIC_SYRK
Expand Down
75 changes: 53 additions & 22 deletions cunumeric/deferred.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
from enum import IntEnum, unique
from functools import reduce, wraps
from inspect import signature
from itertools import product
from itertools import chain, product
from typing import (
TYPE_CHECKING,
Any,
Expand All @@ -36,6 +36,7 @@
import legate.core.types as ty
import numpy as np
from legate.core import Annotation, Future, ReductionOp, Store
from legate.core.store import RegionField
from legate.core.utils import OrderedSet
from numpy.core.numeric import ( # type: ignore [attr-defined]
normalize_axis_tuple,
Expand All @@ -57,7 +58,7 @@
from .linalg.solve import solve
from .sort import sort
from .thunk import NumPyThunk
from .utils import is_advanced_indexing
from .utils import is_advanced_indexing, to_core_dtype

if TYPE_CHECKING:
import numpy.typing as npt
Expand Down Expand Up @@ -261,7 +262,7 @@ def __init__(
super().__init__(runtime, base.type.to_numpy_dtype())
assert base is not None
assert isinstance(base, Store)
self.base: Any = base # a Legate Store
self.base = base # a Legate Store
self.numpy_array = (
None if numpy_array is None else weakref.ref(numpy_array)
)
Expand All @@ -270,11 +271,13 @@ def __str__(self) -> str:
return f"DeferredArray(base: {self.base})"

@property
def storage(self) -> Union[Future, tuple[Region, FieldID]]:
def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]:
storage = self.base.storage
if self.base.kind == Future:
assert isinstance(storage, Future)
return storage
else:
assert isinstance(storage, RegionField)
return (storage.region, storage.field.field_id)

@property
Expand Down Expand Up @@ -402,6 +405,7 @@ def scalar(self) -> bool:

def get_scalar_array(self) -> npt.NDArray[Any]:
assert self.scalar
assert isinstance(self.base.storage, Future)
buf = self.base.storage.get_buffer(self.dtype.itemsize)
result = np.frombuffer(buf, dtype=self.dtype, count=1)
return result.reshape(())
Expand Down Expand Up @@ -770,10 +774,13 @@ def _create_indexing_array(

store = self.base
rhs = self
computed_key: tuple[Any, ...]
if isinstance(key, NumPyThunk):
key = (key,)
assert isinstance(key, tuple)
key = self._unpack_ellipsis(key, self.ndim)
computed_key = (key,)
else:
computed_key = key
assert isinstance(computed_key, tuple)
computed_key = self._unpack_ellipsis(computed_key, self.ndim)

# the index where the first index_array is passed to the [] operator
start_index = -1
Expand All @@ -788,7 +795,7 @@ def _create_indexing_array(
tuple_of_arrays: tuple[Any, ...] = ()

# First, we need to check if transpose is needed
for dim, k in enumerate(key):
for dim, k in enumerate(computed_key):
if np.isscalar(k) or isinstance(k, NumPyThunk):
if start_index == -1:
start_index = dim
Expand All @@ -813,25 +820,29 @@ def _create_indexing_array(
)
transpose_indices += post_indices
post_indices = tuple(
i for i in range(len(key)) if i not in key_transpose_indices
i
for i in range(len(computed_key))
if i not in key_transpose_indices
)
key_transpose_indices += post_indices
store = store.transpose(transpose_indices)
key = tuple(key[i] for i in key_transpose_indices)
computed_key = tuple(
computed_key[i] for i in key_transpose_indices
)

shift = 0
for dim, k in enumerate(key):
for dim, k in enumerate(computed_key):
if np.isscalar(k):
if k < 0: # type: ignore [operator]
k += store.shape[dim + shift]
k += store.shape[dim + shift] # type: ignore [operator]
store = store.project(dim + shift, k)
shift -= 1
elif k is np.newaxis:
store = store.promote(dim + shift, 1)
elif isinstance(k, slice):
k, store = self._slice_store(k, store, dim + shift)
elif isinstance(k, NumPyThunk):
if not isinstance(key, DeferredArray):
if not isinstance(computed_key, DeferredArray):
k = self.runtime.to_deferred_array(k)
if k.dtype == bool:
for i in range(k.ndim):
Expand Down Expand Up @@ -900,7 +911,7 @@ def _get_view(self, key: Any) -> DeferredArray:
k, store = self._slice_store(k, store, dim + shift)
elif np.isscalar(k):
if k < 0: # type: ignore [operator]
k += store.shape[dim + shift]
k += store.shape[dim + shift] # type: ignore [operator]
store = store.project(dim + shift, k)
shift -= 1
else:
Expand Down Expand Up @@ -1329,10 +1340,9 @@ def swapaxes(self, axis1: int, axis2: int) -> DeferredArray:
dims = list(range(self.ndim))
dims[axis1], dims[axis2] = dims[axis2], dims[axis1]

result = self.base.transpose(dims)
result = DeferredArray(self.runtime, result)
result = self.base.transpose(tuple(dims))

return result
return DeferredArray(self.runtime, result)

# Convert the source array to the destination array
@auto_convert("rhs")
Expand Down Expand Up @@ -1738,8 +1748,8 @@ def choose(self, rhs: Any, *args: Any) -> None:

out_arr = self.base
# broadcast input array and all choices arrays to the same shape
index = index_arr._broadcast(out_arr.shape)
ch_tuple = tuple(c._broadcast(out_arr.shape) for c in ch_def)
index = index_arr._broadcast(out_arr.shape.extents)
ch_tuple = tuple(c._broadcast(out_arr.shape.extents) for c in ch_def)

task = self.context.create_auto_task(CuNumericOpCode.CHOOSE)
task.add_output(out_arr)
Expand All @@ -1752,6 +1762,27 @@ def choose(self, rhs: Any, *args: Any) -> None:
task.add_alignment(index, c)
task.execute()

def select(
self,
condlist: Iterable[Any],
choicelist: Iterable[Any],
default: npt.NDArray[Any],
) -> None:
condlist_ = tuple(self.runtime.to_deferred_array(c) for c in condlist)
choicelist_ = tuple(
self.runtime.to_deferred_array(c) for c in choicelist
)

task = self.context.create_auto_task(CuNumericOpCode.SELECT)
out_arr = self.base
task.add_output(out_arr)
for c in chain(condlist_, choicelist_):
c_arr = c._broadcast(self.shape)
task.add_input(c_arr)
task.add_alignment(c_arr, out_arr)
task.add_scalar_arg(default, to_core_dtype(default.dtype))
task.execute()

# Create or extract a diagonal from a matrix
@auto_convert("rhs")
def _diag_helper(
Expand Down Expand Up @@ -1964,9 +1995,9 @@ def tile(self, rhs: Any, reps: Union[Any, Sequence[int]]) -> None:
def transpose(
self, axes: Union[None, tuple[int, ...], list[int]]
) -> DeferredArray:
result = self.base.transpose(axes)
result = DeferredArray(self.runtime, result)
return result
computed_axes = tuple(axes) if axes is not None else ()
result = self.base.transpose(computed_axes)
return DeferredArray(self.runtime, result)

@auto_convert("rhs")
def trilu(self, rhs: Any, k: int, lower: bool) -> None:
Expand Down
23 changes: 22 additions & 1 deletion cunumeric/eager.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
Any,
Callable,
Dict,
Iterable,
Optional,
Sequence,
Union,
Expand Down Expand Up @@ -234,7 +235,7 @@ def __init__(
self.escaped = False

@property
def storage(self) -> Union[Future, tuple[Region, FieldID]]:
def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]:
if self.deferred is None:
self.to_deferred_array()

Expand Down Expand Up @@ -629,6 +630,26 @@ def choose(self, rhs: Any, *args: Any) -> None:
choices = tuple(c.array for c in args)
self.array[:] = np.choose(rhs.array, choices, mode="raise")

def select(
self,
condlist: Iterable[Any],
choicelist: Iterable[Any],
default: npt.NDArray[Any],
) -> None:
self.check_eager_args(*condlist, *choicelist)
if self.deferred is not None:
self.deferred.select(
condlist,
choicelist,
default,
)
else:
self.array[...] = np.select(
tuple(c.array for c in condlist),
tuple(c.array for c in choicelist),
default,
)

def _diag_helper(
self, rhs: Any, offset: int, naxes: int, extract: bool, trace: bool
) -> None:
Expand Down
73 changes: 72 additions & 1 deletion cunumeric/module.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def zeros_like(

def full(
shape: NdShapeLike,
value: Union[int, float],
value: Any,
dtype: Optional[npt.DTypeLike] = None,
) -> ndarray:
"""
Expand Down Expand Up @@ -3743,6 +3743,77 @@ def choose(
return a.choose(choices=choices, out=out, mode=mode)


def select(
condlist: Sequence[npt.ArrayLike | ndarray],
choicelist: Sequence[npt.ArrayLike | ndarray],
default: Any = 0,
) -> ndarray:
"""
Return an array drawn from elements in choicelist, depending on conditions.

Parameters
----------
condlist : list of bool ndarrays
The list of conditions which determine from which array in `choicelist`
the output elements are taken. When multiple conditions are satisfied,
the first one encountered in `condlist` is used.
choicelist : list of ndarrays
The list of arrays from which the output elements are taken. It has
to be of the same length as `condlist`.
default : scalar, optional
The element inserted in `output` when all conditions evaluate to False.

Returns
-------
output : ndarray
The output at position m is the m-th element of the array in
`choicelist` where the m-th element of the corresponding array in
`condlist` is True.

See Also
--------
numpy.select

Availability
--------
Multiple GPUs, Multiple CPUs
"""

if len(condlist) != len(choicelist):
raise ValueError(
"list of cases must be same length as list of conditions"
)
if len(condlist) == 0:
raise ValueError("select with an empty condition list is not possible")

condlist_ = tuple(convert_to_cunumeric_ndarray(c) for c in condlist)
for i, c in enumerate(condlist_):
if c.dtype != bool:
raise TypeError(
f"invalid entry {i} in condlist: should be boolean ndarray"
)

choicelist_ = tuple(convert_to_cunumeric_ndarray(c) for c in choicelist)
common_type = np.result_type(*choicelist_, default)
args = condlist_ + choicelist_
choicelist_ = tuple(
c._maybe_convert(common_type, args) for c in choicelist_
)
default_ = np.array(default, dtype=common_type)

out_shape = np.broadcast_shapes(
*(c.shape for c in condlist_),
*(c.shape for c in choicelist_),
)
out = ndarray(shape=out_shape, dtype=common_type, inputs=args)
out._thunk.select(
tuple(c._thunk for c in condlist_),
tuple(c._thunk for c in choicelist_),
default_,
)
return out


@add_boilerplate("condition", "a")
def compress(
condition: ndarray,
Expand Down
13 changes: 11 additions & 2 deletions cunumeric/thunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from __future__ import annotations

from abc import ABC, abstractmethod, abstractproperty
from typing import TYPE_CHECKING, Any, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Union

from .config import ConvertCode

Expand Down Expand Up @@ -74,7 +74,7 @@ def size(self) -> int:
# Abstract methods

@abstractproperty
def storage(self) -> Union[Future, tuple[Region, FieldID]]:
def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]:
"""Return the Legion storage primitive for this NumPy thunk"""
...

Expand Down Expand Up @@ -191,6 +191,15 @@ def contract(
def choose(self, rhs: Any, *args: Any) -> None:
...

@abstractmethod
def select(
self,
condlist: Iterable[Any],
choicelist: Iterable[Any],
default: npt.NDArray[Any],
) -> None:
...

@abstractmethod
def _diag_helper(
self, rhs: Any, offset: int, naxes: int, extract: bool, trace: bool
Expand Down
Loading