diff --git a/cunumeric/config.py b/cunumeric/config.py index 635544bd8..c18d36f4b 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -197,6 +197,7 @@ class _CunumericSharedLib: CUNUMERIC_SCAN_PROD: int CUNUMERIC_SCAN_SUM: int CUNUMERIC_SEARCHSORTED: int + CUNUMERIC_SELECT: int CUNUMERIC_SOLVE: int CUNUMERIC_SORT: int CUNUMERIC_SYRK: int @@ -365,6 +366,7 @@ class CuNumericOpCode(IntEnum): SCAN_GLOBAL = _cunumeric.CUNUMERIC_SCAN_GLOBAL SCAN_LOCAL = _cunumeric.CUNUMERIC_SCAN_LOCAL SEARCHSORTED = _cunumeric.CUNUMERIC_SEARCHSORTED + SELECT = _cunumeric.CUNUMERIC_SELECT SOLVE = _cunumeric.CUNUMERIC_SOLVE SORT = _cunumeric.CUNUMERIC_SORT SYRK = _cunumeric.CUNUMERIC_SYRK diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index fb288a205..9f9e46057 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -20,7 +20,7 @@ from enum import IntEnum, unique from functools import reduce, wraps from inspect import signature -from itertools import product +from itertools import chain, product from typing import ( TYPE_CHECKING, Any, @@ -36,6 +36,7 @@ import legate.core.types as ty import numpy as np from legate.core import Annotation, Future, ReductionOp, Store +from legate.core.store import RegionField from legate.core.utils import OrderedSet from numpy.core.numeric import ( # type: ignore [attr-defined] normalize_axis_tuple, @@ -57,7 +58,7 @@ from .linalg.solve import solve from .sort import sort from .thunk import NumPyThunk -from .utils import is_advanced_indexing +from .utils import is_advanced_indexing, to_core_dtype if TYPE_CHECKING: import numpy.typing as npt @@ -261,7 +262,7 @@ def __init__( super().__init__(runtime, base.type.to_numpy_dtype()) assert base is not None assert isinstance(base, Store) - self.base: Any = base # a Legate Store + self.base = base # a Legate Store self.numpy_array = ( None if numpy_array is None else weakref.ref(numpy_array) ) @@ -270,11 +271,13 @@ def __str__(self) -> str: return f"DeferredArray(base: {self.base})" @property - def storage(self) -> Union[Future, tuple[Region, FieldID]]: + def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]: storage = self.base.storage if self.base.kind == Future: + assert isinstance(storage, Future) return storage else: + assert isinstance(storage, RegionField) return (storage.region, storage.field.field_id) @property @@ -402,6 +405,7 @@ def scalar(self) -> bool: def get_scalar_array(self) -> npt.NDArray[Any]: assert self.scalar + assert isinstance(self.base.storage, Future) buf = self.base.storage.get_buffer(self.dtype.itemsize) result = np.frombuffer(buf, dtype=self.dtype, count=1) return result.reshape(()) @@ -770,10 +774,13 @@ def _create_indexing_array( store = self.base rhs = self + computed_key: tuple[Any, ...] if isinstance(key, NumPyThunk): - key = (key,) - assert isinstance(key, tuple) - key = self._unpack_ellipsis(key, self.ndim) + computed_key = (key,) + else: + computed_key = key + assert isinstance(computed_key, tuple) + computed_key = self._unpack_ellipsis(computed_key, self.ndim) # the index where the first index_array is passed to the [] operator start_index = -1 @@ -788,7 +795,7 @@ def _create_indexing_array( tuple_of_arrays: tuple[Any, ...] = () # First, we need to check if transpose is needed - for dim, k in enumerate(key): + for dim, k in enumerate(computed_key): if np.isscalar(k) or isinstance(k, NumPyThunk): if start_index == -1: start_index = dim @@ -813,17 +820,21 @@ def _create_indexing_array( ) transpose_indices += post_indices post_indices = tuple( - i for i in range(len(key)) if i not in key_transpose_indices + i + for i in range(len(computed_key)) + if i not in key_transpose_indices ) key_transpose_indices += post_indices store = store.transpose(transpose_indices) - key = tuple(key[i] for i in key_transpose_indices) + computed_key = tuple( + computed_key[i] for i in key_transpose_indices + ) shift = 0 - for dim, k in enumerate(key): + for dim, k in enumerate(computed_key): if np.isscalar(k): if k < 0: # type: ignore [operator] - k += store.shape[dim + shift] + k += store.shape[dim + shift] # type: ignore [operator] store = store.project(dim + shift, k) shift -= 1 elif k is np.newaxis: @@ -831,7 +842,7 @@ def _create_indexing_array( elif isinstance(k, slice): k, store = self._slice_store(k, store, dim + shift) elif isinstance(k, NumPyThunk): - if not isinstance(key, DeferredArray): + if not isinstance(computed_key, DeferredArray): k = self.runtime.to_deferred_array(k) if k.dtype == bool: for i in range(k.ndim): @@ -900,7 +911,7 @@ def _get_view(self, key: Any) -> DeferredArray: k, store = self._slice_store(k, store, dim + shift) elif np.isscalar(k): if k < 0: # type: ignore [operator] - k += store.shape[dim + shift] + k += store.shape[dim + shift] # type: ignore [operator] store = store.project(dim + shift, k) shift -= 1 else: @@ -1329,10 +1340,9 @@ def swapaxes(self, axis1: int, axis2: int) -> DeferredArray: dims = list(range(self.ndim)) dims[axis1], dims[axis2] = dims[axis2], dims[axis1] - result = self.base.transpose(dims) - result = DeferredArray(self.runtime, result) + result = self.base.transpose(tuple(dims)) - return result + return DeferredArray(self.runtime, result) # Convert the source array to the destination array @auto_convert("rhs") @@ -1738,8 +1748,8 @@ def choose(self, rhs: Any, *args: Any) -> None: out_arr = self.base # broadcast input array and all choices arrays to the same shape - index = index_arr._broadcast(out_arr.shape) - ch_tuple = tuple(c._broadcast(out_arr.shape) for c in ch_def) + index = index_arr._broadcast(out_arr.shape.extents) + ch_tuple = tuple(c._broadcast(out_arr.shape.extents) for c in ch_def) task = self.context.create_auto_task(CuNumericOpCode.CHOOSE) task.add_output(out_arr) @@ -1752,6 +1762,27 @@ def choose(self, rhs: Any, *args: Any) -> None: task.add_alignment(index, c) task.execute() + def select( + self, + condlist: Iterable[Any], + choicelist: Iterable[Any], + default: npt.NDArray[Any], + ) -> None: + condlist_ = tuple(self.runtime.to_deferred_array(c) for c in condlist) + choicelist_ = tuple( + self.runtime.to_deferred_array(c) for c in choicelist + ) + + task = self.context.create_auto_task(CuNumericOpCode.SELECT) + out_arr = self.base + task.add_output(out_arr) + for c in chain(condlist_, choicelist_): + c_arr = c._broadcast(self.shape) + task.add_input(c_arr) + task.add_alignment(c_arr, out_arr) + task.add_scalar_arg(default, to_core_dtype(default.dtype)) + task.execute() + # Create or extract a diagonal from a matrix @auto_convert("rhs") def _diag_helper( @@ -1964,9 +1995,9 @@ def tile(self, rhs: Any, reps: Union[Any, Sequence[int]]) -> None: def transpose( self, axes: Union[None, tuple[int, ...], list[int]] ) -> DeferredArray: - result = self.base.transpose(axes) - result = DeferredArray(self.runtime, result) - return result + computed_axes = tuple(axes) if axes is not None else () + result = self.base.transpose(computed_axes) + return DeferredArray(self.runtime, result) @auto_convert("rhs") def trilu(self, rhs: Any, k: int, lower: bool) -> None: diff --git a/cunumeric/eager.py b/cunumeric/eager.py index 7c58023ef..4e6e504c2 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -19,6 +19,7 @@ Any, Callable, Dict, + Iterable, Optional, Sequence, Union, @@ -234,7 +235,7 @@ def __init__( self.escaped = False @property - def storage(self) -> Union[Future, tuple[Region, FieldID]]: + def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]: if self.deferred is None: self.to_deferred_array() @@ -629,6 +630,26 @@ def choose(self, rhs: Any, *args: Any) -> None: choices = tuple(c.array for c in args) self.array[:] = np.choose(rhs.array, choices, mode="raise") + def select( + self, + condlist: Iterable[Any], + choicelist: Iterable[Any], + default: npt.NDArray[Any], + ) -> None: + self.check_eager_args(*condlist, *choicelist) + if self.deferred is not None: + self.deferred.select( + condlist, + choicelist, + default, + ) + else: + self.array[...] = np.select( + tuple(c.array for c in condlist), + tuple(c.array for c in choicelist), + default, + ) + def _diag_helper( self, rhs: Any, offset: int, naxes: int, extract: bool, trace: bool ) -> None: diff --git a/cunumeric/module.py b/cunumeric/module.py index 45d7508b1..0a8132a81 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -378,7 +378,7 @@ def zeros_like( def full( shape: NdShapeLike, - value: Union[int, float], + value: Any, dtype: Optional[npt.DTypeLike] = None, ) -> ndarray: """ @@ -3743,6 +3743,77 @@ def choose( return a.choose(choices=choices, out=out, mode=mode) +def select( + condlist: Sequence[npt.ArrayLike | ndarray], + choicelist: Sequence[npt.ArrayLike | ndarray], + default: Any = 0, +) -> ndarray: + """ + Return an array drawn from elements in choicelist, depending on conditions. + + Parameters + ---------- + condlist : list of bool ndarrays + The list of conditions which determine from which array in `choicelist` + the output elements are taken. When multiple conditions are satisfied, + the first one encountered in `condlist` is used. + choicelist : list of ndarrays + The list of arrays from which the output elements are taken. It has + to be of the same length as `condlist`. + default : scalar, optional + The element inserted in `output` when all conditions evaluate to False. + + Returns + ------- + output : ndarray + The output at position m is the m-th element of the array in + `choicelist` where the m-th element of the corresponding array in + `condlist` is True. + + See Also + -------- + numpy.select + + Availability + -------- + Multiple GPUs, Multiple CPUs + """ + + if len(condlist) != len(choicelist): + raise ValueError( + "list of cases must be same length as list of conditions" + ) + if len(condlist) == 0: + raise ValueError("select with an empty condition list is not possible") + + condlist_ = tuple(convert_to_cunumeric_ndarray(c) for c in condlist) + for i, c in enumerate(condlist_): + if c.dtype != bool: + raise TypeError( + f"invalid entry {i} in condlist: should be boolean ndarray" + ) + + choicelist_ = tuple(convert_to_cunumeric_ndarray(c) for c in choicelist) + common_type = np.result_type(*choicelist_, default) + args = condlist_ + choicelist_ + choicelist_ = tuple( + c._maybe_convert(common_type, args) for c in choicelist_ + ) + default_ = np.array(default, dtype=common_type) + + out_shape = np.broadcast_shapes( + *(c.shape for c in condlist_), + *(c.shape for c in choicelist_), + ) + out = ndarray(shape=out_shape, dtype=common_type, inputs=args) + out._thunk.select( + tuple(c._thunk for c in condlist_), + tuple(c._thunk for c in choicelist_), + default_, + ) + return out + + @add_boilerplate("condition", "a") def compress( condition: ndarray, diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py index 62417271a..68aafb6c9 100644 --- a/cunumeric/thunk.py +++ b/cunumeric/thunk.py @@ -15,7 +15,7 @@ from __future__ import annotations from abc import ABC, abstractmethod, abstractproperty -from typing import TYPE_CHECKING, Any, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Union from .config import ConvertCode @@ -74,7 +74,7 @@ def size(self) -> int: # Abstract methods @abstractproperty - def storage(self) -> Union[Future, tuple[Region, FieldID]]: + def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]: """Return the Legion storage primitive for this NumPy thunk""" ... @@ -191,6 +191,15 @@ def contract( def choose(self, rhs: Any, *args: Any) -> None: ... + @abstractmethod + def select( + self, + condlist: Iterable[Any], + choicelist: Iterable[Any], + default: npt.NDArray[Any], + ) -> None: + ... + @abstractmethod def _diag_helper( self, rhs: Any, offset: int, naxes: int, extract: bool, trace: bool diff --git a/cunumeric_cpp.cmake b/cunumeric_cpp.cmake index f7feee620..be5c0fbe6 100644 --- a/cunumeric_cpp.cmake +++ b/cunumeric_cpp.cmake @@ -137,10 +137,11 @@ list(APPEND cunumeric_SOURCES src/cunumeric/nullary/window.cc src/cunumeric/index/advanced_indexing.cc src/cunumeric/index/choose.cc + src/cunumeric/index/putmask.cc src/cunumeric/index/repeat.cc + src/cunumeric/index/select.cc src/cunumeric/index/wrap.cc src/cunumeric/index/zip.cc - src/cunumeric/index/putmask.cc src/cunumeric/item/read.cc src/cunumeric/item/write.cc src/cunumeric/matrix/batched_cholesky.cc @@ -194,6 +195,7 @@ if(Legion_USE_OpenMP) src/cunumeric/index/choose_omp.cc src/cunumeric/index/putmask_omp.cc src/cunumeric/index/repeat_omp.cc + src/cunumeric/index/select_omp.cc src/cunumeric/index/wrap_omp.cc src/cunumeric/index/zip_omp.cc src/cunumeric/matrix/batched_cholesky_omp.cc @@ -241,10 +243,11 @@ if(Legion_USE_CUDA) src/cunumeric/nullary/window.cu src/cunumeric/index/advanced_indexing.cu src/cunumeric/index/choose.cu + src/cunumeric/index/putmask.cu src/cunumeric/index/repeat.cu + src/cunumeric/index/select.cu src/cunumeric/index/wrap.cu src/cunumeric/index/zip.cu - src/cunumeric/index/putmask.cu src/cunumeric/item/read.cu src/cunumeric/item/write.cu src/cunumeric/matrix/batched_cholesky.cu diff --git a/docs/cunumeric/source/api/indexing.rst b/docs/cunumeric/source/api/indexing.rst index ab02bbcc4..e1a3358cb 100644 --- a/docs/cunumeric/source/api/indexing.rst +++ b/docs/cunumeric/source/api/indexing.rst @@ -32,6 +32,7 @@ Indexing-like operations compress diag diagonal + select take take_along_axis @@ -41,7 +42,7 @@ Inserting data into arrays .. autosummary:: :toctree: generated/ - + fill_diagonal put putmask diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index 99d9bea19..b38ab6620 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -60,6 +60,7 @@ enum CuNumericOpCode { CUNUMERIC_REPEAT, CUNUMERIC_SCALAR_UNARY_RED, CUNUMERIC_SEARCHSORTED, + CUNUMERIC_SELECT, CUNUMERIC_SOLVE, CUNUMERIC_SORT, CUNUMERIC_SYRK, diff --git a/src/cunumeric/index/select.cc b/src/cunumeric/index/select.cc new file mode 100644 index 000000000..d304ae179 --- /dev/null +++ b/src/cunumeric/index/select.cc @@ -0,0 +1,77 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/index/select.h" +#include "cunumeric/index/select_template.inl" + +namespace cunumeric { + +using namespace legate; + +template +struct SelectImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const std::vector>& condlist, + const std::vector>& choicelist, + VAL default_val, + const Rect& rect, + const Pitches& pitches, + bool dense) const + { + const size_t volume = rect.volume(); + uint32_t narrays = condlist.size(); +#ifdef DEBUG_CUNUMERIC + assert(narrays == choicelist.size()); +#endif + + if (dense) { + auto outptr = out.ptr(rect); + for (size_t idx = 0; idx < volume; ++idx) { outptr[idx] = default_val; } + for (int32_t c = (narrays - 1); c >= 0; c--) { + auto condptr = condlist[c].ptr(rect); + auto choiseptr = choicelist[c].ptr(rect); + for (int32_t idx = 0; idx < volume; idx++) { + if (condptr[idx]) outptr[idx] = choiseptr[idx]; + } + } + } else { + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + out[p] = default_val; + } + for (int32_t c = (narrays - 1); c >= 0; c--) { + for (int32_t idx = 0; idx < volume; idx++) { + auto p = pitches.unflatten(idx, rect.lo); + if (condlist[c][p]) { out[p] = choicelist[c][p]; } + } + } + } + } +}; + +/*static*/ void SelectTask::cpu_variant(TaskContext& context) +{ + select_template(context); +} + +namespace // unnamed +{ +static void __attribute__((constructor)) register_tasks(void) { SelectTask::register_variants(); } +} // namespace + +} // namespace cunumeric diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu new file mode 100644 index 000000000..5a11c44ca --- /dev/null +++ b/src/cunumeric/index/select.cu @@ -0,0 +1,115 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/index/select.h" +#include "cunumeric/index/select_template.inl" +#include "cunumeric/cuda_help.h" + +namespace cunumeric { + +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + select_kernel_dense(VAL* outptr, + uint32_t narrays, + legate::Buffer condlist, + legate::Buffer choicelist, + VAL default_val, + int volume) +{ + const size_t idx = global_tid_1d(); + if (idx >= volume) return; + outptr[idx] = default_val; + for (int32_t c = (narrays - 1); c >= 0; c--) { + if (condlist[c][idx]) { outptr[idx] = choicelist[c][idx]; } + } +} + +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + select_kernel(const AccessorWO out, + uint32_t narrays, + const legate::Buffer> condlist, + const legate::Buffer> choicelist, + VAL default_val, + const Rect rect, + const Pitches pitches, + int volume) +{ + const size_t tid = global_tid_1d(); + if (tid >= volume) return; + auto p = pitches.unflatten(tid, rect.lo); + out[p] = default_val; + for (int32_t c = (narrays - 1); c >= 0; c--) { + if (condlist[c][p]) { out[p] = choicelist[c][p]; } + } +} + +using namespace legate; + +template +struct SelectImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const std::vector>& condlist, + const std::vector>& choicelist, + VAL default_val, + const Rect& rect, + const Pitches& pitches, + bool dense) const + { + uint32_t narrays = condlist.size(); +#ifdef DEBUG_CUNUMERIC + assert(narrays == choicelist.size()); +#endif + const size_t blocks = (rect.volume() + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + + auto stream = get_cached_stream(); + + if (dense) { + auto cond_arr = create_buffer(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx].ptr(rect); + auto choice_arr = + create_buffer(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < choicelist.size(); ++idx) + choice_arr[idx] = choicelist[idx].ptr(rect); + + VAL* outptr = out.ptr(rect); + select_kernel_dense<<>>( + outptr, narrays, cond_arr, choice_arr, default_val, rect.volume()); + + } else { // not dense + auto cond_arr = + create_buffer>(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx]; + auto choice_arr = + create_buffer>(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < choicelist.size(); ++idx) choice_arr[idx] = choicelist[idx]; + + select_kernel<<>>( + out, narrays, cond_arr, choice_arr, default_val, rect, pitches, rect.volume()); + } + + CHECK_CUDA_STREAM(stream); + } +}; + +/*static*/ void SelectTask::gpu_variant(TaskContext& context) +{ + select_template(context); +} + +} // namespace cunumeric diff --git a/src/cunumeric/index/select.h b/src/cunumeric/index/select.h new file mode 100644 index 000000000..2179e12b7 --- /dev/null +++ b/src/cunumeric/index/select.h @@ -0,0 +1,43 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +#include "cunumeric/cunumeric.h" + +namespace cunumeric { + +struct SelectArgs { + const Array& out; + const std::vector& inputs; + const legate::Scalar& default_value; +}; + +class SelectTask : public CuNumericTask { + public: + static const int TASK_ID = CUNUMERIC_SELECT; + + public: + static void cpu_variant(legate::TaskContext& context); +#ifdef LEGATE_USE_OPENMP + static void omp_variant(legate::TaskContext& context); +#endif +#ifdef LEGATE_USE_CUDA + static void gpu_variant(legate::TaskContext& context); +#endif +}; + +} // namespace cunumeric diff --git a/src/cunumeric/index/select_omp.cc b/src/cunumeric/index/select_omp.cc new file mode 100644 index 000000000..98a9854cd --- /dev/null +++ b/src/cunumeric/index/select_omp.cc @@ -0,0 +1,76 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/index/select.h" +#include "cunumeric/index/select_template.inl" + +namespace cunumeric { + +using namespace legate; + +template +struct SelectImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const std::vector>& condlist, + const std::vector>& choicelist, + VAL default_val, + const Rect& rect, + const Pitches& pitches, + bool dense) const + { + const size_t volume = rect.volume(); + uint32_t narrays = condlist.size(); +#ifdef DEBUG_CUNUMERIC + assert(narrays == choicelist.size()); +#endif + + if (dense) { + auto outptr = out.ptr(rect); +#pragma omp parallel for schedule(static) + for (size_t idx = 0; idx < volume; ++idx) { outptr[idx] = default_val; } + for (int32_t c = (narrays - 1); c >= 0; c--) { + auto condptr = condlist[c].ptr(rect); + auto choiseptr = choicelist[c].ptr(rect); +#pragma omp parallel for schedule(static) + for (int32_t idx = 0; idx < volume; idx++) { + if (condptr[idx]) outptr[idx] = choiseptr[idx]; + } + } + } else { +#pragma omp parallel for schedule(static) + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + out[p] = default_val; + } + for (int32_t c = (narrays - 1); c >= 0; c--) { +#pragma omp parallel for schedule(static) + for (int32_t idx = 0; idx < volume; idx++) { + auto p = pitches.unflatten(idx, rect.lo); + if (condlist[c][p]) { out[p] = choicelist[c][p]; } + } + } + } + } +}; + +/*static*/ void SelectTask::omp_variant(TaskContext& context) +{ + select_template(context); +} + +} // namespace cunumeric diff --git a/src/cunumeric/index/select_template.inl b/src/cunumeric/index/select_template.inl new file mode 100644 index 000000000..9ebaf19c6 --- /dev/null +++ b/src/cunumeric/index/select_template.inl @@ -0,0 +1,87 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +// Useful for IDEs +#include "cunumeric/index/select.h" +#include "cunumeric/pitches.h" + +namespace cunumeric { + +using namespace legate; + +template +struct SelectImplBody; + +template +struct SelectImpl { + template + void operator()(SelectArgs& args) const + { + using VAL = legate_type_of; + auto out_rect = args.out.shape(); + + Pitches pitches; + size_t volume = pitches.flatten(out_rect); + if (volume == 0) return; + + auto out = args.out.write_accessor(out_rect); + +#ifndef LEGATE_BOUNDS_CHECKS + // Check to see if this is dense or not + bool dense = out.accessor.is_dense_row_major(out_rect); +#else + // No dense execution if we're doing bounds checks + bool dense = false; +#endif + + std::vector> condlist; + condlist.reserve(args.inputs.size() / 2); + for (int32_t i = 0; i < args.inputs.size() / 2; i++) { + auto rect_c = args.inputs[i].shape(); +#ifdef DEBUG_CUNUMERIC + assert(rect_c == out_rect); +#endif + condlist.push_back(args.inputs[i].read_accessor(rect_c)); + dense = dense && condlist.back().accessor.is_dense_row_major(out_rect); + } + + std::vector> choicelist; + choicelist.reserve(args.inputs.size() / 2); + for (int32_t i = args.inputs.size() / 2; i < args.inputs.size(); i++) { + auto rect_c = args.inputs[i].shape(); +#ifdef DEBUG_CUNUMERIC + assert(rect_c == out_rect); +#endif + choicelist.push_back(args.inputs[i].read_accessor(rect_c)); + dense = dense && choicelist.back().accessor.is_dense_row_major(out_rect); + } + + VAL default_value = args.default_value.value(); + SelectImplBody()( + out, condlist, choicelist, default_value, out_rect, pitches, dense); + } +}; + +template +static void select_template(TaskContext& context) +{ + SelectArgs args{context.outputs()[0], context.inputs(), context.scalars()[0]}; + double_dispatch(args.out.dim(), args.out.code(), SelectImpl{}, args); +} + +} // namespace cunumeric diff --git a/tests/integration/test_index_routines.py b/tests/integration/test_index_routines.py index 4488b90b5..d1c42805f 100644 --- a/tests/integration/test_index_routines.py +++ b/tests/integration/test_index_routines.py @@ -268,6 +268,133 @@ def test_out_invalid_shape(self): num.choose(self.a, self.choices, out=aout) +DIM = 7 + +SELECT_SHAPES = ( + (DIM,), + (1, 1), + (1, DIM), + (DIM, 1), + (DIM, 0), + (DIM, DIM), + (1, 1, 1), + (1, 0, DIM), + (DIM, 1, 1), + (1, DIM, 1), + (1, 1, DIM), + (DIM, DIM, DIM), +) + +DEFAULTS = (0, -100, 5) + + +@pytest.mark.parametrize("size", SELECT_SHAPES) +def test_select(size): + # test with 2 conditions/choices + no default passed + arr = np.random.randint(-15, 15, size=size) + cond_np1 = arr > 1 + cond_num1 = num.array(cond_np1) + cond_np2 = arr < 0 + cond_num2 = num.array(cond_np2) + choice_np1 = arr * 10 + choice_num1 = num.array(choice_np1) + choice_np2 = arr * 2 + choice_num2 = num.array(choice_np2) + res_np = np.select( + ( + cond_np1, + cond_np2, + ), + ( + choice_np1, + choice_np2, + ), + ) + res_num = num.select( + ( + cond_num1, + cond_num2, + ), + ( + choice_num1, + choice_num2, + ), + ) + assert np.array_equal(res_np, res_num) + + # test with all False + cond_np = arr > 100 + cond_num = num.array(cond_np) + choice_np = arr * 100 + choice_num = num.array(choice_np) + res_np = np.select(cond_np, choice_np) + res_num = num.select(cond_num, choice_num) + assert np.array_equal(res_np, res_num) + + # test with all True + cond_np = arr < 100 + cond_num = num.array(cond_np) + choice_np = arr * 10 + choice_num = num.array(choice_np) + res_np = np.select(cond_np, choice_np) + res_num = num.select(cond_num, choice_num) + assert np.array_equal(res_np, res_num) + + +def test_select_maxdim(): + for ndim in range(2, LEGATE_MAX_DIM + 1): + a_shape = tuple(np.random.randint(1, 9) for i in range(ndim)) + arr = mk_seq_array(np, a_shape) + condlist_np = list() + choicelist_np = list() + condlist_num = list() + choicelist_num = list() + nlist = np.random.randint(1, 5) + for nl in range(0, nlist): + arr_con = arr > nl * 2 + arr_ch = arr * nl + condlist_np += (arr_con,) + choicelist_np += (arr_ch,) + condlist_num += (num.array(arr_con),) + choicelist_num += (num.array(arr_ch),) + res_np = np.select(condlist_np, choicelist_np) + res_num = num.select(condlist_num, choicelist_num) + assert np.array_equal(res_np, res_num) + + +@pytest.mark.parametrize("size", SELECT_SHAPES) +@pytest.mark.parametrize("default", DEFAULTS) +def test_select_default(size, default): + arr_np = np.random.randint(-5, 5, size=size) + cond_np = arr_np > 1 + cond_num = num.array(cond_np) + choice_np = arr_np**2 + choice_num = num.array(choice_np) + res_np = np.select(cond_np, choice_np, default) + res_num = num.select(cond_num, choice_num, default) + assert np.array_equal(res_np, res_num) + + +SELECT_ZERO_SHAPES = ( + (0,), + (0, 1), +) + + +@pytest.mark.parametrize("size", SELECT_ZERO_SHAPES) +def test_select_zero_shape(size): + arr_np = np.random.randint(-15, 15, size=size) + cond_np = arr_np > 1 + cond_num = num.array(cond_np) + choice_np = arr_np * 10 + choice_num = num.array(choice_np) + msg = "select with an empty condition list is not possible" + with pytest.raises(ValueError, match=msg): + np.select(cond_np, choice_np) + with pytest.raises(ValueError, match=msg): + num.select(cond_num, choice_num) + + def test_diagonal(): ad = np.arange(24).reshape(4, 3, 2) num_ad = num.array(ad) diff --git a/tests/unit/cunumeric/test_config.py b/tests/unit/cunumeric/test_config.py index 6f8f43df5..f829c0279 100644 --- a/tests/unit/cunumeric/test_config.py +++ b/tests/unit/cunumeric/test_config.py @@ -144,6 +144,7 @@ def test_CuNumericOpCode() -> None: "RAND", "READ", "REPEAT", + "SELECT", "SCALAR_UNARY_RED", "SCAN_GLOBAL", "SCAN_LOCAL",