From 9bd6d410250f0590bec3c9cfe70c67c3abae09d8 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Thu, 2 Nov 2023 15:56:13 -0700 Subject: [PATCH 01/17] Support np.select operation --- cunumeric/config.py | 2 + cunumeric/deferred.py | 27 ++++++++- cunumeric/eager.py | 21 +++++++ cunumeric/module.py | 73 +++++++++++++++++++++++- cunumeric/thunk.py | 11 +++- cunumeric_cpp.cmake | 7 ++- docs/cunumeric/source/api/indexing.rst | 3 +- src/cunumeric/cunumeric_c.h | 1 + src/cunumeric/index/select.cc | 18 ++++++ src/cunumeric/index/select.cu | 43 ++++++++++++++ src/cunumeric/index/select.h | 19 ++++++ src/cunumeric/index/select_omp.cc | 18 ++++++ src/cunumeric/index/select_template.inl | 21 +++++++ tests/integration/test_index_routines.py | 3 + 14 files changed, 259 insertions(+), 8 deletions(-) create mode 100644 src/cunumeric/index/select.cc create mode 100644 src/cunumeric/index/select.cu create mode 100644 src/cunumeric/index/select.h create mode 100644 src/cunumeric/index/select_omp.cc create mode 100644 src/cunumeric/index/select_template.inl diff --git a/cunumeric/config.py b/cunumeric/config.py index bdea334a1..5020f5a4a 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -196,6 +196,7 @@ class _CunumericSharedLib: CUNUMERIC_SCAN_PROD: int CUNUMERIC_SCAN_SUM: int CUNUMERIC_SEARCHSORTED: int + CUNUMERIC_SELECT: int CUNUMERIC_SOLVE: int CUNUMERIC_SORT: int CUNUMERIC_SYRK: int @@ -363,6 +364,7 @@ class CuNumericOpCode(IntEnum): SCAN_GLOBAL = _cunumeric.CUNUMERIC_SCAN_GLOBAL SCAN_LOCAL = _cunumeric.CUNUMERIC_SCAN_LOCAL SEARCHSORTED = _cunumeric.CUNUMERIC_SEARCHSORTED + SELECT = _cunumeric.SELECT SOLVE = _cunumeric.CUNUMERIC_SOLVE SORT = _cunumeric.CUNUMERIC_SORT SYRK = _cunumeric.CUNUMERIC_SYRK diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index a67d9912d..fd3844f53 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -20,7 +20,7 @@ from enum import IntEnum, unique from functools import reduce, wraps from inspect import signature -from itertools import product +from itertools import chain, product from typing import ( TYPE_CHECKING, Any, @@ -57,7 +57,7 @@ from .linalg.solve import solve from .sort import sort from .thunk import NumPyThunk -from .utils import is_advanced_indexing +from .utils import is_advanced_indexing, to_core_dtype if TYPE_CHECKING: import numpy.typing as npt @@ -261,7 +261,7 @@ def __init__( super().__init__(runtime, base.type.to_numpy_dtype()) assert base is not None assert isinstance(base, Store) - self.base: Any = base # a Legate Store + self.base = base # a Legate Store self.numpy_array = ( None if numpy_array is None else weakref.ref(numpy_array) ) @@ -1752,6 +1752,27 @@ def choose(self, rhs: Any, *args: Any) -> None: task.add_alignment(index, c) task.execute() + def select( + self, + condlist: Iterable[Any], + choicelist: Iterable[Any], + default: npt.NDArray[Any], + ) -> None: + condlist_ = tuple(self.runtime.to_deferred_array(c) for c in condlist) + choicelist_ = tuple( + self.runtime.to_deferred_array(c) for c in choicelist + ) + + task = self.context.create_auto_task(CuNumericOpCode.SELECT) + out_arr = self.base + task.add_output(out_arr) + for c in chain(condlist_, choicelist_): + c_arr = c._broadcast(self.shape) + task.add_input(c_arr) + task.add_alignment(c_arr, out_arr) + task.add_scalar_arg(default, to_core_dtype(default.dtype)) + task.execute() + # Create or extract a diagonal from a matrix @auto_convert("rhs") def _diag_helper( diff --git a/cunumeric/eager.py b/cunumeric/eager.py index 63284eb94..1e9137393 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -19,6 +19,7 @@ Any, Callable, Dict, + Iterable, Optional, Sequence, Union, @@ -629,6 +630,26 @@ def choose(self, rhs: Any, *args: Any) -> None: choices = tuple(c.array for c in args) self.array[:] = np.choose(rhs.array, choices, mode="raise") + def select( + self, + condlist: Iterable[Any], + choicelist: Iterable[Any], + default: npt.NDArray[Any], + ) -> None: + self.check_eager_args(*condlist, *choicelist) + if self.deferred is not None: + self.deferred.select( + condlist, + choicelist, + default, + ) + else: + self.array[:] = np.select( + tuple(c.array for c in condlist), + tuple(c.array for c in choicelist), + default, + ) + def _diag_helper( self, rhs: Any, offset: int, naxes: int, extract: bool, trace: bool ) -> None: diff --git a/cunumeric/module.py b/cunumeric/module.py index e2bbc78f7..7ba081437 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -378,7 +378,7 @@ def zeros_like( def full( shape: NdShapeLike, - value: Union[int, float], + value: Any, dtype: Optional[npt.DTypeLike] = None, ) -> ndarray: """ @@ -3664,6 +3664,77 @@ def choose( return a.choose(choices=choices, out=out, mode=mode) +def select( + condlist: Sequence[npt.ArrayLike | ndarray], + choicelist: Sequence[npt.ArrayLike | ndarray], + default: Any = 0, +) -> ndarray: + """ + Return an array drawn from elements in choicelist, depending on conditions. + + Parameters + ---------- + condlist : list of bool ndarrays + The list of conditions which determine from which array in `choicelist` + the output elements are taken. When multiple conditions are satisfied, + the first one encountered in `condlist` is used. + choicelist : list of ndarrays + The list of arrays from which the output elements are taken. It has + to be of the same length as `condlist`. + default : scalar, optional + The element inserted in `output` when all conditions evaluate to False. + + Returns + ------- + output : ndarray + The output at position m is the m-th element of the array in + `choicelist` where the m-th element of the corresponding array in + `condlist` is True. + + See Also + -------- + numpy.select + + Availability + -------- + Multiple GPUs, Multiple CPUs + """ + + if len(condlist) != len(choicelist): + raise ValueError( + "list of cases must be same length as list of conditions" + ) + if len(condlist) == 0: + raise ValueError("select with an empty condition list is not possible") + + condlist_ = tuple(convert_to_cunumeric_ndarray(c) for c in condlist) + for i, c in enumerate(condlist_): + if c.dtype != bool: + raise TypeError( + f"invalid entry {i} in condlist: should be boolean ndarray" + ) + + choicelist_ = tuple(convert_to_cunumeric_ndarray(c) for c in choicelist) + common_type = np.result_type(*choicelist_, default) + args = condlist_ + choicelist_ + choicelist_ = tuple( + c._maybe_convert(common_type, args) for c in choicelist_ + ) + default_ = np.array(default, dtype=common_type) + + out_shape = np.broadcast_shapes( + *(c.shape for c in condlist_), + *(c.shape for c in choicelist_), + ) + out = ndarray(shape=out_shape, dtype=common_type, inputs=args) + out._thunk.select( + tuple(c._thunk for c in condlist_), + tuple(c._thunk for c in choicelist_), + default_, + ) + return out + + @add_boilerplate("condition", "a") def compress( condition: ndarray, diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py index 62417271a..0a831bce9 100644 --- a/cunumeric/thunk.py +++ b/cunumeric/thunk.py @@ -15,7 +15,7 @@ from __future__ import annotations from abc import ABC, abstractmethod, abstractproperty -from typing import TYPE_CHECKING, Any, Optional, Sequence, Union +from typing import TYPE_CHECKING, Any, Iterable, Optional, Sequence, Union from .config import ConvertCode @@ -191,6 +191,15 @@ def contract( def choose(self, rhs: Any, *args: Any) -> None: ... + @abstractmethod + def select( + self, + condlist: Iterable[Any], + choicelist: Iterable[Any], + default: npt.NDArray[Any], + ) -> None: + ... + @abstractmethod def _diag_helper( self, rhs: Any, offset: int, naxes: int, extract: bool, trace: bool diff --git a/cunumeric_cpp.cmake b/cunumeric_cpp.cmake index 4270962ba..0dd81b5c7 100644 --- a/cunumeric_cpp.cmake +++ b/cunumeric_cpp.cmake @@ -137,10 +137,11 @@ list(APPEND cunumeric_SOURCES src/cunumeric/nullary/window.cc src/cunumeric/index/advanced_indexing.cc src/cunumeric/index/choose.cc + src/cunumeric/index/putmask.cc src/cunumeric/index/repeat.cc + src/cunumeric/index/select.cc src/cunumeric/index/wrap.cc src/cunumeric/index/zip.cc - src/cunumeric/index/putmask.cc src/cunumeric/item/read.cc src/cunumeric/item/write.cc src/cunumeric/matrix/contract.cc @@ -193,6 +194,7 @@ if(Legion_USE_OpenMP) src/cunumeric/index/choose_omp.cc src/cunumeric/index/putmask_omp.cc src/cunumeric/index/repeat_omp.cc + src/cunumeric/index/select_omp.cc src/cunumeric/index/wrap_omp.cc src/cunumeric/index/zip_omp.cc src/cunumeric/matrix/contract_omp.cc @@ -239,10 +241,11 @@ if(Legion_USE_CUDA) src/cunumeric/nullary/window.cu src/cunumeric/index/advanced_indexing.cu src/cunumeric/index/choose.cu + src/cunumeric/index/putmask.cu src/cunumeric/index/repeat.cu + src/cunumeric/index/select.cu src/cunumeric/index/wrap.cu src/cunumeric/index/zip.cu - src/cunumeric/index/putmask.cu src/cunumeric/item/read.cu src/cunumeric/item/write.cu src/cunumeric/matrix/contract.cu diff --git a/docs/cunumeric/source/api/indexing.rst b/docs/cunumeric/source/api/indexing.rst index ab02bbcc4..e1a3358cb 100644 --- a/docs/cunumeric/source/api/indexing.rst +++ b/docs/cunumeric/source/api/indexing.rst @@ -32,6 +32,7 @@ Indexing-like operations compress diag diagonal + select take take_along_axis @@ -41,7 +42,7 @@ Inserting data into arrays .. autosummary:: :toctree: generated/ - + fill_diagonal put putmask diff --git a/src/cunumeric/cunumeric_c.h b/src/cunumeric/cunumeric_c.h index b5b392835..d0be7aa32 100644 --- a/src/cunumeric/cunumeric_c.h +++ b/src/cunumeric/cunumeric_c.h @@ -59,6 +59,7 @@ enum CuNumericOpCode { CUNUMERIC_REPEAT, CUNUMERIC_SCALAR_UNARY_RED, CUNUMERIC_SEARCHSORTED, + CUNUMERIC_SELECT, CUNUMERIC_SOLVE, CUNUMERIC_SORT, CUNUMERIC_SYRK, diff --git a/src/cunumeric/index/select.cc b/src/cunumeric/index/select.cc new file mode 100644 index 000000000..87e3983aa --- /dev/null +++ b/src/cunumeric/index/select.cc @@ -0,0 +1,18 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/index/select.h" +#include "cunumeric/index/select_template.inl" diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu new file mode 100644 index 000000000..a03a4dcda --- /dev/null +++ b/src/cunumeric/index/select.cu @@ -0,0 +1,43 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/index/select.h" +#include "cunumeric/index/select_template.inl" +#include "cunumeric/cuda_help.h" + +namespace cunumeric { + +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + select_kernel_dense(VAL* outptr, + uint32_t narrays, + legate::Buffer condlist, + legate::Buffer choicelist, + VAL default_val, + int volume) +{ + const size_t idx = global_tid_1d(); + if (idx >= volume) return; + for (uint32_t c = 0; c < narrays; ++c) { + if (condlist[c][idx]) { + outptr[idx] = choicelist[c][idx]; + return; + } + } + outptr[idx] = default_val; +} + +} // namespace cunumeric diff --git a/src/cunumeric/index/select.h b/src/cunumeric/index/select.h new file mode 100644 index 000000000..9d14df5bb --- /dev/null +++ b/src/cunumeric/index/select.h @@ -0,0 +1,19 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +#include "cunumeric/cunumeric.h" diff --git a/src/cunumeric/index/select_omp.cc b/src/cunumeric/index/select_omp.cc new file mode 100644 index 000000000..87e3983aa --- /dev/null +++ b/src/cunumeric/index/select_omp.cc @@ -0,0 +1,18 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/index/select.h" +#include "cunumeric/index/select_template.inl" diff --git a/src/cunumeric/index/select_template.inl b/src/cunumeric/index/select_template.inl new file mode 100644 index 000000000..a6019594c --- /dev/null +++ b/src/cunumeric/index/select_template.inl @@ -0,0 +1,21 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#pragma once + +// Useful for IDEs +#include "cunumeric/index/select.h" +#include "cunumeric/pitches.h" diff --git a/tests/integration/test_index_routines.py b/tests/integration/test_index_routines.py index c86f62e97..2068bb1cd 100644 --- a/tests/integration/test_index_routines.py +++ b/tests/integration/test_index_routines.py @@ -268,6 +268,9 @@ def test_out_invalid_shape(self): num.choose(self.a, self.choices, out=aout) +# TODO: test select + + def test_diagonal(): ad = np.arange(24).reshape(4, 3, 2) num_ad = num.array(ad) From d3a89fae95923328a574dbee07e6713d83e0b40d Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 6 Nov 2023 18:54:08 -0800 Subject: [PATCH 02/17] applying fixes for pre-commit provided by Bryan --- cunumeric/config.py | 2 +- cunumeric/deferred.py | 42 ++++++++++++++++++++++++------------------ cunumeric/eager.py | 2 +- cunumeric/thunk.py | 2 +- 4 files changed, 27 insertions(+), 21 deletions(-) diff --git a/cunumeric/config.py b/cunumeric/config.py index 5020f5a4a..f01cbb6be 100644 --- a/cunumeric/config.py +++ b/cunumeric/config.py @@ -364,7 +364,7 @@ class CuNumericOpCode(IntEnum): SCAN_GLOBAL = _cunumeric.CUNUMERIC_SCAN_GLOBAL SCAN_LOCAL = _cunumeric.CUNUMERIC_SCAN_LOCAL SEARCHSORTED = _cunumeric.CUNUMERIC_SEARCHSORTED - SELECT = _cunumeric.SELECT + SELECT = _cunumeric.CUNUMERIC_SELECT SOLVE = _cunumeric.CUNUMERIC_SOLVE SORT = _cunumeric.CUNUMERIC_SORT SYRK = _cunumeric.CUNUMERIC_SYRK diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index fd3844f53..597b899ff 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -36,6 +36,7 @@ import legate.core.types as ty import numpy as np from legate.core import Annotation, Future, ReductionOp, Store +from legate.core.store import RegionField from legate.core.utils import OrderedSet from numpy.core.numeric import ( # type: ignore [attr-defined] normalize_axis_tuple, @@ -270,11 +271,13 @@ def __str__(self) -> str: return f"DeferredArray(base: {self.base})" @property - def storage(self) -> Union[Future, tuple[Region, FieldID]]: + def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]: storage = self.base.storage if self.base.kind == Future: + assert isinstance(storage, Future) return storage else: + assert isinstance(storage, RegionField) return (storage.region, storage.field.field_id) @property @@ -402,6 +405,7 @@ def scalar(self) -> bool: def get_scalar_array(self) -> npt.NDArray[Any]: assert self.scalar + assert isinstance(self.base.storage, Future) buf = self.base.storage.get_buffer(self.dtype.itemsize) result = np.frombuffer(buf, dtype=self.dtype, count=1) return result.reshape(()) @@ -770,10 +774,11 @@ def _create_indexing_array( store = self.base rhs = self + computed_key: tuple[Any, ...] if isinstance(key, NumPyThunk): - key = (key,) + computed_key = (key,) assert isinstance(key, tuple) - key = self._unpack_ellipsis(key, self.ndim) + computed_key = self._unpack_ellipsis(key, self.ndim) # the index where the first index_array is passed to the [] operator start_index = -1 @@ -788,7 +793,7 @@ def _create_indexing_array( tuple_of_arrays: tuple[Any, ...] = () # First, we need to check if transpose is needed - for dim, k in enumerate(key): + for dim, k in enumerate(computed_key): if np.isscalar(k) or isinstance(k, NumPyThunk): if start_index == -1: start_index = dim @@ -813,17 +818,19 @@ def _create_indexing_array( ) transpose_indices += post_indices post_indices = tuple( - i for i in range(len(key)) if i not in key_transpose_indices + i + for i in range(len(computed_key)) + if i not in key_transpose_indices ) key_transpose_indices += post_indices store = store.transpose(transpose_indices) - key = tuple(key[i] for i in key_transpose_indices) + key = tuple(computed_key[i] for i in key_transpose_indices) shift = 0 - for dim, k in enumerate(key): + for dim, k in enumerate(computed_key): if np.isscalar(k): if k < 0: # type: ignore [operator] - k += store.shape[dim + shift] + k += store.shape[dim + shift] # type: ignore [operator] store = store.project(dim + shift, k) shift -= 1 elif k is np.newaxis: @@ -831,7 +838,7 @@ def _create_indexing_array( elif isinstance(k, slice): k, store = self._slice_store(k, store, dim + shift) elif isinstance(k, NumPyThunk): - if not isinstance(key, DeferredArray): + if not isinstance(computed_key, DeferredArray): k = self.runtime.to_deferred_array(k) if k.dtype == bool: for i in range(k.ndim): @@ -900,7 +907,7 @@ def _get_view(self, key: Any) -> DeferredArray: k, store = self._slice_store(k, store, dim + shift) elif np.isscalar(k): if k < 0: # type: ignore [operator] - k += store.shape[dim + shift] + k += store.shape[dim + shift] # type: ignore [operator] store = store.project(dim + shift, k) shift -= 1 else: @@ -1329,10 +1336,9 @@ def swapaxes(self, axis1: int, axis2: int) -> DeferredArray: dims = list(range(self.ndim)) dims[axis1], dims[axis2] = dims[axis2], dims[axis1] - result = self.base.transpose(dims) - result = DeferredArray(self.runtime, result) + result = self.base.transpose(tuple(dims)) - return result + return DeferredArray(self.runtime, result) # Convert the source array to the destination array @auto_convert("rhs") @@ -1738,8 +1744,8 @@ def choose(self, rhs: Any, *args: Any) -> None: out_arr = self.base # broadcast input array and all choices arrays to the same shape - index = index_arr._broadcast(out_arr.shape) - ch_tuple = tuple(c._broadcast(out_arr.shape) for c in ch_def) + index = index_arr._broadcast(out_arr.shape.extents) + ch_tuple = tuple(c._broadcast(out_arr.shape.extents) for c in ch_def) task = self.context.create_auto_task(CuNumericOpCode.CHOOSE) task.add_output(out_arr) @@ -1985,9 +1991,9 @@ def tile(self, rhs: Any, reps: Union[Any, Sequence[int]]) -> None: def transpose( self, axes: Union[None, tuple[int, ...], list[int]] ) -> DeferredArray: - result = self.base.transpose(axes) - result = DeferredArray(self.runtime, result) - return result + computed_axes = tuple(axes) if axes is not None else () + result = self.base.transpose(computed_axes) + return DeferredArray(self.runtime, result) @auto_convert("rhs") def trilu(self, rhs: Any, k: int, lower: bool) -> None: diff --git a/cunumeric/eager.py b/cunumeric/eager.py index 1e9137393..56a056b6f 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -235,7 +235,7 @@ def __init__( self.escaped = False @property - def storage(self) -> Union[Future, tuple[Region, FieldID]]: + def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]: if self.deferred is None: self.to_deferred_array() diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py index 0a831bce9..68aafb6c9 100644 --- a/cunumeric/thunk.py +++ b/cunumeric/thunk.py @@ -74,7 +74,7 @@ def size(self) -> int: # Abstract methods @abstractproperty - def storage(self) -> Union[Future, tuple[Region, FieldID]]: + def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]: """Return the Legion storage primitive for this NumPy thunk""" ... From fa32b7a4dd9a246feae71f1f25d471bfb98b6dea Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Wed, 8 Nov 2023 12:30:48 -0800 Subject: [PATCH 03/17] implementing C++ part for np.select --- cunumeric/eager.py | 2 +- src/cunumeric/index/select.cc | 59 +++++++++++ src/cunumeric/index/select.cu | 88 +++++++++++++++- src/cunumeric/index/select.h | 24 +++++ src/cunumeric/index/select_omp.cc | 61 +++++++++++ src/cunumeric/index/select_template.inl | 64 ++++++++++++ tests/integration/test_index_routines.py | 126 ++++++++++++++++++++++- 7 files changed, 417 insertions(+), 7 deletions(-) diff --git a/cunumeric/eager.py b/cunumeric/eager.py index 56a056b6f..91efd777a 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -644,7 +644,7 @@ def select( default, ) else: - self.array[:] = np.select( + self.array[...] = np.select( tuple(c.array for c in condlist), tuple(c.array for c in choicelist), default, diff --git a/src/cunumeric/index/select.cc b/src/cunumeric/index/select.cc index 87e3983aa..3ebd52746 100644 --- a/src/cunumeric/index/select.cc +++ b/src/cunumeric/index/select.cc @@ -16,3 +16,62 @@ #include "cunumeric/index/select.h" #include "cunumeric/index/select_template.inl" + +namespace cunumeric { + +using namespace legate; + +template +struct SelectImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const std::vector>& condlist, + const std::vector>& choicelist, + VAL default_val, + const Rect& rect, + const Pitches& pitches, + bool dense) const + { + const size_t volume = rect.volume(); + uint32_t narrays = condlist.size(); +#ifdef DEBUG_CUNUMERIC + assert(narrays == choicelist.size()); +#endif + + if (dense) { + auto outptr = out.ptr(rect); + for (size_t idx = 0; idx < volume; ++idx) outptr[idx] = default_val; + for (int32_t c = (narrays - 1); c >= 0; c--) { + auto condptr = condlist[c].ptr(rect); + auto choiseptr = choicelist[c].ptr(rect); + for (int32_t idx = (volume - 1); idx >= 0; idx--) { + if (condptr[idx]) outptr[idx] = choiseptr[idx]; + } + } + } else { + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + out[p] = default_val; + } + for (int32_t c = (narrays - 1); c >= 0; c--) { + for (int32_t idx = (volume - 1); idx >= 0; idx--) { + auto p = pitches.unflatten(idx, rect.lo); + if (condlist[c][p]) out[p] = choicelist[c][p]; + } + } + } + } +}; + +/*static*/ void SelectTask::cpu_variant(TaskContext& context) +{ + select_template(context); +} + +namespace // unnamed +{ +static void __attribute__((constructor)) register_tasks(void) { SelectTask::register_variants(); } +} // namespace + +} // namespace cunumeric diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu index a03a4dcda..e99fac0e9 100644 --- a/src/cunumeric/index/select.cu +++ b/src/cunumeric/index/select.cu @@ -31,13 +31,91 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) { const size_t idx = global_tid_1d(); if (idx >= volume) return; - for (uint32_t c = 0; c < narrays; ++c) { - if (condlist[c][idx]) { - outptr[idx] = choicelist[c][idx]; - return; + outptr[idx] = default_val; + for (int32_t c = (narrays - 1); c >= 0; c--) { + if (condlist[c][idx]) { outptr[idx] = choicelist[c][idx]; } + } +} + +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + select_kernel(const AccessorWO out, + uint32_t narrays, + const legate::Buffer> condlist, + const legate::Buffer> choicelist, + VAL default_val, + const Rect rect, + const Pitches pitches, + int out_size, + int volume) +{ + const size_t tid = global_tid_1d(); + if (tid >= out_size) return; + for (int32_t idx = (volume - out_size + tid); idx >= 0; idx -= out_size) { + auto p = pitches.unflatten(idx, rect.lo); + out[p] = default_val; + } + __syncthreads(); + for (int32_t c = (narrays - 1); c >= 0; c--) { + for (int32_t idx = (volume - out_size + tid); idx >= 0; idx -= out_size) { + auto p = pitches.unflatten(idx, rect.lo); + if (condlist[c][p]) { out[p] = choicelist[c][p]; } } } - outptr[idx] = default_val; +} + +using namespace legate; + +template +struct SelectImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const std::vector>& condlist, + const std::vector>& choicelist, + VAL default_val, + const Rect& rect, + const Pitches& pitches, + bool dense) const + { + const size_t out_size = rect.hi[0] - rect.lo[0] + 1; + uint32_t narrays = condlist.size(); +#ifdef DEBUG_CUNUMERIC + assert(narrays == choicelist.size()); +#endif + const size_t blocks = (out_size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + + auto stream = get_cached_stream(); + if (dense && (DIM <= 1 || rect.volume() == 0)) { + auto cond_arr = create_buffer(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx].ptr(rect); + auto choice_arr = + create_buffer(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < choicelist.size(); ++idx) + choice_arr[idx] = choicelist[idx].ptr(rect); + VAL* outptr = out.ptr(rect); + select_kernel_dense<<>>( + outptr, narrays, cond_arr, choice_arr, default_val, out_size); + } else { + auto cond_arr = + create_buffer>(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx]; + + auto choice_arr = + create_buffer>(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < choicelist.size(); ++idx) choice_arr[idx] = choicelist[idx]; + if (out_size == 0) return; + select_kernel<<>>( + out, narrays, cond_arr, choice_arr, default_val, rect, pitches, out_size, rect.volume()); + } + + CHECK_CUDA_STREAM(stream); + } +}; + +/*static*/ void SelectTask::gpu_variant(TaskContext& context) +{ + select_template(context); } } // namespace cunumeric diff --git a/src/cunumeric/index/select.h b/src/cunumeric/index/select.h index 9d14df5bb..2179e12b7 100644 --- a/src/cunumeric/index/select.h +++ b/src/cunumeric/index/select.h @@ -17,3 +17,27 @@ #pragma once #include "cunumeric/cunumeric.h" + +namespace cunumeric { + +struct SelectArgs { + const Array& out; + const std::vector& inputs; + const legate::Scalar& default_value; +}; + +class SelectTask : public CuNumericTask { + public: + static const int TASK_ID = CUNUMERIC_SELECT; + + public: + static void cpu_variant(legate::TaskContext& context); +#ifdef LEGATE_USE_OPENMP + static void omp_variant(legate::TaskContext& context); +#endif +#ifdef LEGATE_USE_CUDA + static void gpu_variant(legate::TaskContext& context); +#endif +}; + +} // namespace cunumeric diff --git a/src/cunumeric/index/select_omp.cc b/src/cunumeric/index/select_omp.cc index 87e3983aa..f015208a7 100644 --- a/src/cunumeric/index/select_omp.cc +++ b/src/cunumeric/index/select_omp.cc @@ -16,3 +16,64 @@ #include "cunumeric/index/select.h" #include "cunumeric/index/select_template.inl" + +namespace cunumeric { + +using namespace legate; + +template +struct SelectImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const std::vector>& condlist, + const std::vector>& choicelist, + VAL default_val, + const Rect& rect, + const Pitches& pitches, + bool dense) const + { + const size_t volume = rect.volume(); + uint32_t narrays = condlist.size(); +#ifdef DEBUG_CUNUMERIC + assert(narrays == choicelist.size()); +#endif + + if (dense && DIM <= 1) { + auto outptr = out.ptr(rect); +#pragma omp parallel for schedule(static) + for (size_t idx = 0; idx < volume; ++idx) outptr[idx] = default_val; + for (int32_t c = (narrays - 1); c >= 0; c--) { + auto condptr = condlist[c].ptr(rect); + auto choiseptr = choicelist[c].ptr(rect); +#pragma omp parallel for schedule(static) + for (int32_t idx = (volume - 1); idx >= 0; idx--) { + if (condptr[idx]) outptr[idx] = choiseptr[idx]; + } + } + } else { +#pragma omp parallel for schedule(static) + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + out[p] = default_val; + } + const size_t out_size = rect.hi[0] - rect.lo[0] + 1; + for (int32_t c = (narrays - 1); c >= 0; c--) { +#pragma omp parallel for schedule(static) + for (int32_t out_idx = 0; out_idx <= out_size; out_idx++) { + for (int32_t idx = (volume - out_size + out_idx); idx >= 0; idx -= out_size) { + auto p = pitches.unflatten(idx, rect.lo); + if (condlist[c][p]) out[p] = choicelist[c][p]; + } + } + } + } + } +}; + +/*static*/ void SelectTask::omp_variant(TaskContext& context) +{ + select_template(context); +} + +} // namespace cunumeric diff --git a/src/cunumeric/index/select_template.inl b/src/cunumeric/index/select_template.inl index a6019594c..c37ff19aa 100644 --- a/src/cunumeric/index/select_template.inl +++ b/src/cunumeric/index/select_template.inl @@ -19,3 +19,67 @@ // Useful for IDEs #include "cunumeric/index/select.h" #include "cunumeric/pitches.h" + +namespace cunumeric { + +using namespace legate; + +template +struct SelectImplBody; + +template +struct SelectImpl { + template + void operator()(SelectArgs& args) const + { + using VAL = legate_type_of; + auto out_rect = args.out.shape(); + + Pitches pitches; + size_t volume = pitches.flatten(out_rect); + if (volume == 0) return; + + auto out = args.out.write_accessor(out_rect); + +#ifndef LEGATE_BOUNDS_CHECKS + // Check to see if this is dense or not + bool dense = out.accessor.is_dense_row_major(out_rect); +#else + // No dense execution if we're doing bounds checks + bool dense = false; +#endif + + std::vector> condlist; + for (int i = 0; i < args.inputs.size() / 2; i++) { + auto rect_c = args.inputs[i].shape(); +#ifdef DEBUG_CUNUMERIC + assert(rect_c == out_rect); +#endif + condlist.push_back(args.inputs[i].read_accessor(rect_c)); + dense = dense && condlist[i].accessor.is_dense_row_major(out_rect); + } + + std::vector> choicelist; + for (int i = args.inputs.size() / 2; i < args.inputs.size(); i++) { + auto rect_c = args.inputs[i].shape(); +#ifdef DEBUG_CUNUMERIC + assert(rect_c == out_rect); +#endif + choicelist.push_back(args.inputs[i].read_accessor(rect_c)); + dense = dense && choicelist[i - args.inputs.size() / 2].accessor.is_dense_row_major(out_rect); + } + + VAL default_value = args.default_value.value(); + SelectImplBody()( + out, condlist, choicelist, default_value, out_rect, pitches, dense); + } +}; + +template +static void select_template(TaskContext& context) +{ + SelectArgs args{context.outputs()[0], context.inputs(), context.scalars()[0]}; + double_dispatch(args.out.dim(), args.out.code(), SelectImpl{}, args); +} + +} // namespace cunumeric diff --git a/tests/integration/test_index_routines.py b/tests/integration/test_index_routines.py index 2068bb1cd..64ee32be6 100644 --- a/tests/integration/test_index_routines.py +++ b/tests/integration/test_index_routines.py @@ -268,7 +268,131 @@ def test_out_invalid_shape(self): num.choose(self.a, self.choices, out=aout) -# TODO: test select +DIM = 7 + +SELECT_SHAPES = ( + (DIM,), + (1, 1), + (1, DIM), + (DIM, 1), + (DIM, 0), + (DIM, DIM), + (1, 1, 1), + (1, 0, DIM), + (DIM, 1, 1), + (1, DIM, 1), + (1, 1, DIM), + (DIM, DIM, DIM), +) + +DEFAULTS = (0, -100, 5) + + +@pytest.mark.parametrize("size", SELECT_SHAPES) +def test_select(size): + # test with 2 conditions/choices + no default passed + arr = np.random.randint(-15, 15, size=size) + cond_np1 = arr > 1 + cond_num1 = num.array(cond_np1) + cond_np2 = arr < 0 + cond_num2 = num.array(cond_np2) + choice_np1 = arr * 10 + choice_num1 = num.array(choice_np1) + choice_np2 = arr * 2 + choice_num2 = num.array(choice_np2) + res_np = np.select( + ( + cond_np1, + cond_np2, + ), + ( + choice_np1, + choice_np2, + ), + ) + res_num = num.select( + ( + cond_num1, + cond_num2, + ), + ( + choice_num1, + choice_num2, + ), + ) + assert np.array_equal(res_np, res_num) + + # test with all False + cond_np = arr > 100 + cond_num = num.array(cond_np) + choice_np = arr * 100 + choice_num = num.array(choice_np) + res_np = np.select(cond_np, choice_np) + res_num = num.select(cond_num, choice_num) + assert np.array_equal(res_np, res_num) + + # test with all True + cond_np = arr < 100 + cond_num = num.array(cond_np) + choice_np = arr * 10 + choice_num = num.array(choice_np) + res_np = np.select(cond_np, choice_np) + res_num = num.select(cond_num, choice_num) + assert np.array_equal(res_np, res_num) + + +def test_select_maxdim(): + for ndim in range(2, LEGATE_MAX_DIM + 1): + a_shape = tuple(np.random.randint(1, 9) for i in range(ndim)) + arr = mk_seq_array(np, a_shape) + condlist_np = list() + choicelist_np = list() + condlist_num = list() + choicelist_num = list() + nlist = np.random.randint(1, 5) + for nl in range(0, nlist): + arr_con = arr > nl * 2 + arr_ch = arr * nl + condlist_np += (arr_con,) + choicelist_np += (arr_ch,) + condlist_num += (num.array(arr_con),) + choicelist_num += (num.array(arr_ch),) + res_np = np.select(condlist_np, choicelist_np) + res_num = num.select(condlist_num, choicelist_num) + assert np.array_equal(res_np, res_num) + + +@pytest.mark.parametrize("size", SELECT_SHAPES) +@pytest.mark.parametrize("default", DEFAULTS) +def test_select_default(size, default): + arr_np = np.random.randint(-5, 5, size=size) + cond_np = arr_np > 1 + cond_num = num.array(cond_np) + choice_np = arr_np**2 + choice_num = num.array(choice_np) + res_np = np.select(cond_np, choice_np, default) + res_num = num.select(cond_num, choice_num, default) + assert np.array_equal(res_np, res_num) + + +SELECT_ZERO_SHAPES = ( + (0,), + (0, 1), +) + + +@pytest.mark.parametrize("size", SELECT_ZERO_SHAPES) +def test_select_zero_shape(size): + arr_np = np.random.randint(-15, 15, size=size) + cond_np = arr_np > 1 + cond_num = num.array(cond_np) + choice_np = arr_np * 10 + choice_num = num.array(choice_np) + msg = "select with an empty condition list is not possible" + with pytest.raises(ValueError, match=msg): + np.select(cond_np, choice_np) + with pytest.raises(ValueError, match=msg): + num.select(cond_num, choice_num) def test_diagonal(): From d9b044b4c183c729b645db876cd7f5dcd1785533 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 14 Nov 2023 12:52:23 -0800 Subject: [PATCH 04/17] fixing CI --- cunumeric/deferred.py | 10 +++++++--- 1 file changed, 7 insertions(+), 3 deletions(-) diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py index 597b899ff..d92fda515 100644 --- a/cunumeric/deferred.py +++ b/cunumeric/deferred.py @@ -777,8 +777,10 @@ def _create_indexing_array( computed_key: tuple[Any, ...] if isinstance(key, NumPyThunk): computed_key = (key,) - assert isinstance(key, tuple) - computed_key = self._unpack_ellipsis(key, self.ndim) + else: + computed_key = key + assert isinstance(computed_key, tuple) + computed_key = self._unpack_ellipsis(computed_key, self.ndim) # the index where the first index_array is passed to the [] operator start_index = -1 @@ -824,7 +826,9 @@ def _create_indexing_array( ) key_transpose_indices += post_indices store = store.transpose(transpose_indices) - key = tuple(computed_key[i] for i in key_transpose_indices) + computed_key = tuple( + computed_key[i] for i in key_transpose_indices + ) shift = 0 for dim, k in enumerate(computed_key): From 579cfe1e97c90984aead6083c7a88ed0a49374d2 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 14 Nov 2023 13:10:42 -0800 Subject: [PATCH 05/17] formatting --- src/cunumeric/index/select.cu | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu index e99fac0e9..ceaa8259c 100644 --- a/src/cunumeric/index/select.cu +++ b/src/cunumeric/index/select.cu @@ -86,6 +86,7 @@ struct SelectImplBody { const size_t blocks = (out_size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; auto stream = get_cached_stream(); + if (dense && (DIM <= 1 || rect.volume() == 0)) { auto cond_arr = create_buffer(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx].ptr(rect); @@ -93,17 +94,19 @@ struct SelectImplBody { create_buffer(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); for (uint32_t idx = 0; idx < choicelist.size(); ++idx) choice_arr[idx] = choicelist[idx].ptr(rect); + VAL* outptr = out.ptr(rect); select_kernel_dense<<>>( outptr, narrays, cond_arr, choice_arr, default_val, out_size); - } else { + + } else { // not dense auto cond_arr = create_buffer>(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx]; - auto choice_arr = create_buffer>(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); for (uint32_t idx = 0; idx < choicelist.size(); ++idx) choice_arr[idx] = choicelist[idx]; + if (out_size == 0) return; select_kernel<<>>( out, narrays, cond_arr, choice_arr, default_val, rect, pitches, out_size, rect.volume()); From d30dcd7c1698be34f43ba409b373760e1409c378 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Mon, 20 Nov 2023 09:48:55 -0800 Subject: [PATCH 06/17] fixing test_config for CI --- tests/unit/cunumeric/test_config.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/unit/cunumeric/test_config.py b/tests/unit/cunumeric/test_config.py index 5e85ccfde..5c6bc2a8a 100644 --- a/tests/unit/cunumeric/test_config.py +++ b/tests/unit/cunumeric/test_config.py @@ -143,6 +143,7 @@ def test_CuNumericOpCode() -> None: "RAND", "READ", "REPEAT", + "SELECT", "SCALAR_UNARY_RED", "SCAN_GLOBAL", "SCAN_LOCAL", From a3f1e1d8e0f68454d5b243ae5e91857d1495bb94 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Sat, 2 Dec 2023 20:58:26 -0800 Subject: [PATCH 07/17] Update src/cunumeric/index/select.h Co-authored-by: Manolis Papadakis --- src/cunumeric/index/select.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cunumeric/index/select.h b/src/cunumeric/index/select.h index 2179e12b7..e4e1e172a 100644 --- a/src/cunumeric/index/select.h +++ b/src/cunumeric/index/select.h @@ -22,7 +22,7 @@ namespace cunumeric { struct SelectArgs { const Array& out; - const std::vector& inputs; + std::vector inputs; const legate::Scalar& default_value; }; From 0e62f2b0f5d5a08bc6b569c9780efba950d5f7e3 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Sat, 2 Dec 2023 20:58:56 -0800 Subject: [PATCH 08/17] Update src/cunumeric/index/select_template.inl Co-authored-by: Manolis Papadakis --- src/cunumeric/index/select_template.inl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cunumeric/index/select_template.inl b/src/cunumeric/index/select_template.inl index c37ff19aa..c6f71d4bd 100644 --- a/src/cunumeric/index/select_template.inl +++ b/src/cunumeric/index/select_template.inl @@ -50,7 +50,8 @@ struct SelectImpl { #endif std::vector> condlist; - for (int i = 0; i < args.inputs.size() / 2; i++) { + condlist.reserve(args.inputs.size() / 2); + for (int32_t i = 0; i < args.inputs.size() / 2; i++) { auto rect_c = args.inputs[i].shape(); #ifdef DEBUG_CUNUMERIC assert(rect_c == out_rect); From d435631e39b10940c6bce5453512820a4c505b59 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Sat, 2 Dec 2023 20:59:17 -0800 Subject: [PATCH 09/17] Update src/cunumeric/index/select_template.inl Co-authored-by: Manolis Papadakis --- src/cunumeric/index/select_template.inl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/cunumeric/index/select_template.inl b/src/cunumeric/index/select_template.inl index c6f71d4bd..1f6717fde 100644 --- a/src/cunumeric/index/select_template.inl +++ b/src/cunumeric/index/select_template.inl @@ -61,7 +61,8 @@ struct SelectImpl { } std::vector> choicelist; - for (int i = args.inputs.size() / 2; i < args.inputs.size(); i++) { + choicelist.reserve(args.inputs.size() / 2); + for (int32_t i = args.inputs.size() / 2; i < args.inputs.size(); i++) { auto rect_c = args.inputs[i].shape(); #ifdef DEBUG_CUNUMERIC assert(rect_c == out_rect); From cb5716a77237c27c43025760340d649815216ba2 Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Sat, 2 Dec 2023 20:59:39 -0800 Subject: [PATCH 10/17] Update src/cunumeric/index/select_template.inl Co-authored-by: Manolis Papadakis --- src/cunumeric/index/select_template.inl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cunumeric/index/select_template.inl b/src/cunumeric/index/select_template.inl index 1f6717fde..3231f6256 100644 --- a/src/cunumeric/index/select_template.inl +++ b/src/cunumeric/index/select_template.inl @@ -57,7 +57,7 @@ struct SelectImpl { assert(rect_c == out_rect); #endif condlist.push_back(args.inputs[i].read_accessor(rect_c)); - dense = dense && condlist[i].accessor.is_dense_row_major(out_rect); + dense = dense && condlist.back().accessor.is_dense_row_major(out_rect); } std::vector> choicelist; From 02ef7bd4a123e4356c6d448de4e814bc9cb6891b Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Sat, 2 Dec 2023 20:59:46 -0800 Subject: [PATCH 11/17] Update src/cunumeric/index/select_template.inl Co-authored-by: Manolis Papadakis --- src/cunumeric/index/select_template.inl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cunumeric/index/select_template.inl b/src/cunumeric/index/select_template.inl index 3231f6256..9ebaf19c6 100644 --- a/src/cunumeric/index/select_template.inl +++ b/src/cunumeric/index/select_template.inl @@ -68,7 +68,7 @@ struct SelectImpl { assert(rect_c == out_rect); #endif choicelist.push_back(args.inputs[i].read_accessor(rect_c)); - dense = dense && choicelist[i - args.inputs.size() / 2].accessor.is_dense_row_major(out_rect); + dense = dense && choicelist.back().accessor.is_dense_row_major(out_rect); } VAL default_value = args.default_value.value(); From 348bf7d83d5c4273b810856b2942046e67f0418b Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Sat, 2 Dec 2023 21:02:14 -0800 Subject: [PATCH 12/17] Update src/cunumeric/index/select.cu Co-authored-by: Manolis Papadakis --- src/cunumeric/index/select.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu index ceaa8259c..edce0b28a 100644 --- a/src/cunumeric/index/select.cu +++ b/src/cunumeric/index/select.cu @@ -97,7 +97,7 @@ struct SelectImplBody { VAL* outptr = out.ptr(rect); select_kernel_dense<<>>( - outptr, narrays, cond_arr, choice_arr, default_val, out_size); + outptr, narrays, cond_arr, choice_arr, default_val, rect.volume()); } else { // not dense auto cond_arr = From ca3a9d0857c490ac274b325c18611dc051dbd4cb Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Sat, 2 Dec 2023 21:02:24 -0800 Subject: [PATCH 13/17] Update src/cunumeric/index/select.cu Co-authored-by: Manolis Papadakis --- src/cunumeric/index/select.cu | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu index edce0b28a..4a943eea7 100644 --- a/src/cunumeric/index/select.cu +++ b/src/cunumeric/index/select.cu @@ -83,7 +83,7 @@ struct SelectImplBody { #ifdef DEBUG_CUNUMERIC assert(narrays == choicelist.size()); #endif - const size_t blocks = (out_size + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + const size_t blocks = (rect.volume() + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; auto stream = get_cached_stream(); From 4ec0a81427ceaddb36aaf430f589e35a8bbf13fb Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Sat, 2 Dec 2023 22:32:47 -0800 Subject: [PATCH 14/17] addressing PR comments --- src/cunumeric/index/select.cu | 3 +-- src/cunumeric/index/select.h | 2 +- src/cunumeric/index/select_omp.cc | 2 +- 3 files changed, 3 insertions(+), 4 deletions(-) diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu index 4a943eea7..01958151f 100644 --- a/src/cunumeric/index/select.cu +++ b/src/cunumeric/index/select.cu @@ -87,7 +87,7 @@ struct SelectImplBody { auto stream = get_cached_stream(); - if (dense && (DIM <= 1 || rect.volume() == 0)) { + if (dense) { auto cond_arr = create_buffer(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx].ptr(rect); auto choice_arr = @@ -107,7 +107,6 @@ struct SelectImplBody { create_buffer>(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); for (uint32_t idx = 0; idx < choicelist.size(); ++idx) choice_arr[idx] = choicelist[idx]; - if (out_size == 0) return; select_kernel<<>>( out, narrays, cond_arr, choice_arr, default_val, rect, pitches, out_size, rect.volume()); } diff --git a/src/cunumeric/index/select.h b/src/cunumeric/index/select.h index e4e1e172a..2179e12b7 100644 --- a/src/cunumeric/index/select.h +++ b/src/cunumeric/index/select.h @@ -22,7 +22,7 @@ namespace cunumeric { struct SelectArgs { const Array& out; - std::vector inputs; + const std::vector& inputs; const legate::Scalar& default_value; }; diff --git a/src/cunumeric/index/select_omp.cc b/src/cunumeric/index/select_omp.cc index f015208a7..ea0aed585 100644 --- a/src/cunumeric/index/select_omp.cc +++ b/src/cunumeric/index/select_omp.cc @@ -39,7 +39,7 @@ struct SelectImplBody { assert(narrays == choicelist.size()); #endif - if (dense && DIM <= 1) { + if (dense) { auto outptr = out.ptr(rect); #pragma omp parallel for schedule(static) for (size_t idx = 0; idx < volume; ++idx) outptr[idx] = default_val; From 2b67bf95009afe1efd7c2b71869a9cc2fe04844a Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Tue, 5 Dec 2023 12:53:59 -0800 Subject: [PATCH 15/17] changing order of the for loop to simplify logic --- src/cunumeric/index/select.cc | 13 +++++-------- src/cunumeric/index/select.cu | 13 ++++++------- src/cunumeric/index/select_omp.cc | 15 +++++++-------- 3 files changed, 18 insertions(+), 23 deletions(-) diff --git a/src/cunumeric/index/select.cc b/src/cunumeric/index/select.cc index 3ebd52746..83f574676 100644 --- a/src/cunumeric/index/select.cc +++ b/src/cunumeric/index/select.cc @@ -41,23 +41,20 @@ struct SelectImplBody { if (dense) { auto outptr = out.ptr(rect); - for (size_t idx = 0; idx < volume; ++idx) outptr[idx] = default_val; for (int32_t c = (narrays - 1); c >= 0; c--) { auto condptr = condlist[c].ptr(rect); auto choiseptr = choicelist[c].ptr(rect); - for (int32_t idx = (volume - 1); idx >= 0; idx--) { + for (int32_t idx = 0; idx < volume; idx++) { + if (c == (narrays - 1)) { outptr[idx] = default_val; } if (condptr[idx]) outptr[idx] = choiseptr[idx]; } } } else { - for (size_t idx = 0; idx < volume; ++idx) { - auto p = pitches.unflatten(idx, rect.lo); - out[p] = default_val; - } for (int32_t c = (narrays - 1); c >= 0; c--) { - for (int32_t idx = (volume - 1); idx >= 0; idx--) { + for (int32_t idx = 0; idx < volume; idx++) { auto p = pitches.unflatten(idx, rect.lo); - if (condlist[c][p]) out[p] = choicelist[c][p]; + if (c == (narrays - 1)) { out[p] = default_val; }; + if (condlist[c][p]) { out[p] = choicelist[c][p]; } } } } diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu index 01958151f..04214a7d3 100644 --- a/src/cunumeric/index/select.cu +++ b/src/cunumeric/index/select.cu @@ -51,15 +51,14 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) { const size_t tid = global_tid_1d(); if (tid >= out_size) return; - for (int32_t idx = (volume - out_size + tid); idx >= 0; idx -= out_size) { - auto p = pitches.unflatten(idx, rect.lo); - out[p] = default_val; - } - __syncthreads(); for (int32_t c = (narrays - 1); c >= 0; c--) { - for (int32_t idx = (volume - out_size + tid); idx >= 0; idx -= out_size) { + for (int32_t idx = 0; idx <= (volume - out_size + tid); idx += out_size) { auto p = pitches.unflatten(idx, rect.lo); - if (condlist[c][p]) { out[p] = choicelist[c][p]; } + if (condlist[c][p]) { + out[p] = choicelist[c][p]; + } else if (c == (narrays - 1)) { + out[p] = default_val; + } } } } diff --git a/src/cunumeric/index/select_omp.cc b/src/cunumeric/index/select_omp.cc index ea0aed585..a6bbcdbb4 100644 --- a/src/cunumeric/index/select_omp.cc +++ b/src/cunumeric/index/select_omp.cc @@ -47,23 +47,22 @@ struct SelectImplBody { auto condptr = condlist[c].ptr(rect); auto choiseptr = choicelist[c].ptr(rect); #pragma omp parallel for schedule(static) - for (int32_t idx = (volume - 1); idx >= 0; idx--) { + for (int32_t idx = 0; idx < volume; idx++) { if (condptr[idx]) outptr[idx] = choiseptr[idx]; } } } else { -#pragma omp parallel for schedule(static) - for (size_t idx = 0; idx < volume; ++idx) { - auto p = pitches.unflatten(idx, rect.lo); - out[p] = default_val; - } const size_t out_size = rect.hi[0] - rect.lo[0] + 1; for (int32_t c = (narrays - 1); c >= 0; c--) { #pragma omp parallel for schedule(static) for (int32_t out_idx = 0; out_idx <= out_size; out_idx++) { - for (int32_t idx = (volume - out_size + out_idx); idx >= 0; idx -= out_size) { + for (int32_t idx = 0; idx <= (volume - out_size + out_idx); idx += out_size) { auto p = pitches.unflatten(idx, rect.lo); - if (condlist[c][p]) out[p] = choicelist[c][p]; + if (condlist[c][p]) + out[p] = choicelist[c][p]; + else if (c == (narrays - 1)) { + out[p] = default_val; + } } } } From 738af4ac14dca25c499860a5898a8a9fb315ca2b Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Wed, 6 Dec 2023 16:12:57 -0800 Subject: [PATCH 16/17] addressing PR comments --- src/cunumeric/index/select.cc | 7 +- src/cunumeric/index/select.cu | 19 ++--- src/cunumeric/index/select2.cu | 122 ++++++++++++++++++++++++++++++ src/cunumeric/index/select_omp.cc | 20 +++-- 4 files changed, 142 insertions(+), 26 deletions(-) create mode 100644 src/cunumeric/index/select2.cu diff --git a/src/cunumeric/index/select.cc b/src/cunumeric/index/select.cc index 83f574676..d304ae179 100644 --- a/src/cunumeric/index/select.cc +++ b/src/cunumeric/index/select.cc @@ -41,19 +41,22 @@ struct SelectImplBody { if (dense) { auto outptr = out.ptr(rect); + for (size_t idx = 0; idx < volume; ++idx) { outptr[idx] = default_val; } for (int32_t c = (narrays - 1); c >= 0; c--) { auto condptr = condlist[c].ptr(rect); auto choiseptr = choicelist[c].ptr(rect); for (int32_t idx = 0; idx < volume; idx++) { - if (c == (narrays - 1)) { outptr[idx] = default_val; } if (condptr[idx]) outptr[idx] = choiseptr[idx]; } } } else { + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + out[p] = default_val; + } for (int32_t c = (narrays - 1); c >= 0; c--) { for (int32_t idx = 0; idx < volume; idx++) { auto p = pitches.unflatten(idx, rect.lo); - if (c == (narrays - 1)) { out[p] = default_val; }; if (condlist[c][p]) { out[p] = choicelist[c][p]; } } } diff --git a/src/cunumeric/index/select.cu b/src/cunumeric/index/select.cu index 04214a7d3..5a11c44ca 100644 --- a/src/cunumeric/index/select.cu +++ b/src/cunumeric/index/select.cu @@ -46,20 +46,14 @@ __global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) VAL default_val, const Rect rect, const Pitches pitches, - int out_size, int volume) { const size_t tid = global_tid_1d(); - if (tid >= out_size) return; + if (tid >= volume) return; + auto p = pitches.unflatten(tid, rect.lo); + out[p] = default_val; for (int32_t c = (narrays - 1); c >= 0; c--) { - for (int32_t idx = 0; idx <= (volume - out_size + tid); idx += out_size) { - auto p = pitches.unflatten(idx, rect.lo); - if (condlist[c][p]) { - out[p] = choicelist[c][p]; - } else if (c == (narrays - 1)) { - out[p] = default_val; - } - } + if (condlist[c][p]) { out[p] = choicelist[c][p]; } } } @@ -77,8 +71,7 @@ struct SelectImplBody { const Pitches& pitches, bool dense) const { - const size_t out_size = rect.hi[0] - rect.lo[0] + 1; - uint32_t narrays = condlist.size(); + uint32_t narrays = condlist.size(); #ifdef DEBUG_CUNUMERIC assert(narrays == choicelist.size()); #endif @@ -107,7 +100,7 @@ struct SelectImplBody { for (uint32_t idx = 0; idx < choicelist.size(); ++idx) choice_arr[idx] = choicelist[idx]; select_kernel<<>>( - out, narrays, cond_arr, choice_arr, default_val, rect, pitches, out_size, rect.volume()); + out, narrays, cond_arr, choice_arr, default_val, rect, pitches, rect.volume()); } CHECK_CUDA_STREAM(stream); diff --git a/src/cunumeric/index/select2.cu b/src/cunumeric/index/select2.cu new file mode 100644 index 000000000..fd5076052 --- /dev/null +++ b/src/cunumeric/index/select2.cu @@ -0,0 +1,122 @@ +/* Copyright 2023 NVIDIA Corporation + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * + */ + +#include "cunumeric/index/select.h" +#include "cunumeric/index/select_template.inl" +#include "cunumeric/cuda_help.h" + +namespace cunumeric { + +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + select_kernel_dense(VAL* outptr, + uint32_t narrays, + legate::Buffer condlist, + legate::Buffer choicelist, + VAL default_val, + int volume) +{ + const size_t idx = global_tid_1d(); + if (idx >= volume) return; + outptr[idx] = default_val; + for (int32_t c = (narrays - 1); c >= 0; c--) { + if (condlist[c][idx]) { outptr[idx] = choicelist[c][idx]; } + } +} + +template +__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) + select_kernel(const AccessorWO out, + uint32_t narrays, + const legate::Buffer> condlist, + const legate::Buffer> choicelist, + VAL default_val, + const Rect rect, + const Pitches pitches, + int out_size, + int volume) +{ + const size_t tid = global_tid_1d(); + if (tid >= out_size) return; + for (int32_t idx = 0; idx <= (volume - out_size + tid); idx += out_size) { + auto p = pitches.unflatten(idx, rect.lo); + out[p] = default_val; + } + for (int32_t c = (narrays - 1); c >= 0; c--) { + for (int32_t idx = 0; idx <= (volume - out_size + tid); idx += out_size) { + auto p = pitches.unflatten(idx, rect.lo); + if (condlist[c][p]) { out[p] = choicelist[c][p]; } + } + } +} + +using namespace legate; + +template +struct SelectImplBody { + using VAL = legate_type_of; + + void operator()(const AccessorWO& out, + const std::vector>& condlist, + const std::vector>& choicelist, + VAL default_val, + const Rect& rect, + const Pitches& pitches, + bool dense) const + { + const size_t out_size = rect.hi[0] - rect.lo[0] + 1; + uint32_t narrays = condlist.size(); +#ifdef DEBUG_CUNUMERIC + assert(narrays == choicelist.size()); +#endif + const size_t blocks = (rect.volume() + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; + + auto stream = get_cached_stream(); + + if (dense) { + auto cond_arr = create_buffer(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx].ptr(rect); + auto choice_arr = + create_buffer(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < choicelist.size(); ++idx) + choice_arr[idx] = choicelist[idx].ptr(rect); + + VAL* outptr = out.ptr(rect); + select_kernel_dense<<>>( + outptr, narrays, cond_arr, choice_arr, default_val, rect.volume()); + + } else { // not dense + auto cond_arr = + create_buffer>(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx]; + auto choice_arr = + create_buffer>(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); + for (uint32_t idx = 0; idx < choicelist.size(); ++idx) choice_arr[idx] = choicelist[idx]; + + select_kernel<<>>( + out, narrays, cond_arr, choice_arr, default_val, rect, pitches, out_size, rect.volume()); + } + + CHECK_CUDA_STREAM(stream); + } +}; + +/*static*/ void SelectTask::gpu_variant(TaskContext& context) +{ + select_template(context); +} + +} // namespace cunumeric diff --git a/src/cunumeric/index/select_omp.cc b/src/cunumeric/index/select_omp.cc index a6bbcdbb4..98a9854cd 100644 --- a/src/cunumeric/index/select_omp.cc +++ b/src/cunumeric/index/select_omp.cc @@ -42,7 +42,7 @@ struct SelectImplBody { if (dense) { auto outptr = out.ptr(rect); #pragma omp parallel for schedule(static) - for (size_t idx = 0; idx < volume; ++idx) outptr[idx] = default_val; + for (size_t idx = 0; idx < volume; ++idx) { outptr[idx] = default_val; } for (int32_t c = (narrays - 1); c >= 0; c--) { auto condptr = condlist[c].ptr(rect); auto choiseptr = choicelist[c].ptr(rect); @@ -52,18 +52,16 @@ struct SelectImplBody { } } } else { - const size_t out_size = rect.hi[0] - rect.lo[0] + 1; +#pragma omp parallel for schedule(static) + for (size_t idx = 0; idx < volume; ++idx) { + auto p = pitches.unflatten(idx, rect.lo); + out[p] = default_val; + } for (int32_t c = (narrays - 1); c >= 0; c--) { #pragma omp parallel for schedule(static) - for (int32_t out_idx = 0; out_idx <= out_size; out_idx++) { - for (int32_t idx = 0; idx <= (volume - out_size + out_idx); idx += out_size) { - auto p = pitches.unflatten(idx, rect.lo); - if (condlist[c][p]) - out[p] = choicelist[c][p]; - else if (c == (narrays - 1)) { - out[p] = default_val; - } - } + for (int32_t idx = 0; idx < volume; idx++) { + auto p = pitches.unflatten(idx, rect.lo); + if (condlist[c][p]) { out[p] = choicelist[c][p]; } } } } From c9867199a13aaaadacbf98d9642b9ddced909f9f Mon Sep 17 00:00:00 2001 From: Irina Demeshko Date: Thu, 7 Dec 2023 09:21:54 -0800 Subject: [PATCH 17/17] removing tmp file --- src/cunumeric/index/select2.cu | 122 --------------------------------- 1 file changed, 122 deletions(-) delete mode 100644 src/cunumeric/index/select2.cu diff --git a/src/cunumeric/index/select2.cu b/src/cunumeric/index/select2.cu deleted file mode 100644 index fd5076052..000000000 --- a/src/cunumeric/index/select2.cu +++ /dev/null @@ -1,122 +0,0 @@ -/* Copyright 2023 NVIDIA Corporation - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - * - */ - -#include "cunumeric/index/select.h" -#include "cunumeric/index/select_template.inl" -#include "cunumeric/cuda_help.h" - -namespace cunumeric { - -template -__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - select_kernel_dense(VAL* outptr, - uint32_t narrays, - legate::Buffer condlist, - legate::Buffer choicelist, - VAL default_val, - int volume) -{ - const size_t idx = global_tid_1d(); - if (idx >= volume) return; - outptr[idx] = default_val; - for (int32_t c = (narrays - 1); c >= 0; c--) { - if (condlist[c][idx]) { outptr[idx] = choicelist[c][idx]; } - } -} - -template -__global__ static void __launch_bounds__(THREADS_PER_BLOCK, MIN_CTAS_PER_SM) - select_kernel(const AccessorWO out, - uint32_t narrays, - const legate::Buffer> condlist, - const legate::Buffer> choicelist, - VAL default_val, - const Rect rect, - const Pitches pitches, - int out_size, - int volume) -{ - const size_t tid = global_tid_1d(); - if (tid >= out_size) return; - for (int32_t idx = 0; idx <= (volume - out_size + tid); idx += out_size) { - auto p = pitches.unflatten(idx, rect.lo); - out[p] = default_val; - } - for (int32_t c = (narrays - 1); c >= 0; c--) { - for (int32_t idx = 0; idx <= (volume - out_size + tid); idx += out_size) { - auto p = pitches.unflatten(idx, rect.lo); - if (condlist[c][p]) { out[p] = choicelist[c][p]; } - } - } -} - -using namespace legate; - -template -struct SelectImplBody { - using VAL = legate_type_of; - - void operator()(const AccessorWO& out, - const std::vector>& condlist, - const std::vector>& choicelist, - VAL default_val, - const Rect& rect, - const Pitches& pitches, - bool dense) const - { - const size_t out_size = rect.hi[0] - rect.lo[0] + 1; - uint32_t narrays = condlist.size(); -#ifdef DEBUG_CUNUMERIC - assert(narrays == choicelist.size()); -#endif - const size_t blocks = (rect.volume() + THREADS_PER_BLOCK - 1) / THREADS_PER_BLOCK; - - auto stream = get_cached_stream(); - - if (dense) { - auto cond_arr = create_buffer(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); - for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx].ptr(rect); - auto choice_arr = - create_buffer(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); - for (uint32_t idx = 0; idx < choicelist.size(); ++idx) - choice_arr[idx] = choicelist[idx].ptr(rect); - - VAL* outptr = out.ptr(rect); - select_kernel_dense<<>>( - outptr, narrays, cond_arr, choice_arr, default_val, rect.volume()); - - } else { // not dense - auto cond_arr = - create_buffer>(condlist.size(), legate::Memory::Kind::Z_COPY_MEM); - for (uint32_t idx = 0; idx < condlist.size(); ++idx) cond_arr[idx] = condlist[idx]; - auto choice_arr = - create_buffer>(choicelist.size(), legate::Memory::Kind::Z_COPY_MEM); - for (uint32_t idx = 0; idx < choicelist.size(); ++idx) choice_arr[idx] = choicelist[idx]; - - select_kernel<<>>( - out, narrays, cond_arr, choice_arr, default_val, rect, pitches, out_size, rect.volume()); - } - - CHECK_CUDA_STREAM(stream); - } -}; - -/*static*/ void SelectTask::gpu_variant(TaskContext& context) -{ - select_template(context); -} - -} // namespace cunumeric