-
Notifications
You must be signed in to change notification settings - Fork 79
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Support np.select operation #1066
Support np.select operation #1066
Conversation
@manopapad Below is a diff that fixes the mypy issues for me. I am not sure the proper fix for the Edit: FYI I was also just speculating about handling the None axes case this way
since diff --git a/cunumeric/config.py b/cunumeric/config.py
index 5020f5a4..f01cbb6b 100644
--- a/cunumeric/config.py
+++ b/cunumeric/config.py
@@ -364,7 +364,7 @@ class CuNumericOpCode(IntEnum):
SCAN_GLOBAL = _cunumeric.CUNUMERIC_SCAN_GLOBAL
SCAN_LOCAL = _cunumeric.CUNUMERIC_SCAN_LOCAL
SEARCHSORTED = _cunumeric.CUNUMERIC_SEARCHSORTED
- SELECT = _cunumeric.SELECT
+ SELECT = _cunumeric.CUNUMERIC_SELECT
SOLVE = _cunumeric.CUNUMERIC_SOLVE
SORT = _cunumeric.CUNUMERIC_SORT
SYRK = _cunumeric.CUNUMERIC_SYRK
diff --git a/cunumeric/deferred.py b/cunumeric/deferred.py
index fd3844f5..597b899f 100644
--- a/cunumeric/deferred.py
+++ b/cunumeric/deferred.py
@@ -36,6 +36,7 @@ from typing import (
import legate.core.types as ty
import numpy as np
from legate.core import Annotation, Future, ReductionOp, Store
+from legate.core.store import RegionField
from legate.core.utils import OrderedSet
from numpy.core.numeric import ( # type: ignore [attr-defined]
normalize_axis_tuple,
@@ -270,11 +271,13 @@ class DeferredArray(NumPyThunk):
return f"DeferredArray(base: {self.base})"
@property
- def storage(self) -> Union[Future, tuple[Region, FieldID]]:
+ def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]:
storage = self.base.storage
if self.base.kind == Future:
+ assert isinstance(storage, Future)
return storage
else:
+ assert isinstance(storage, RegionField)
return (storage.region, storage.field.field_id)
@property
@@ -402,6 +405,7 @@ class DeferredArray(NumPyThunk):
def get_scalar_array(self) -> npt.NDArray[Any]:
assert self.scalar
+ assert isinstance(self.base.storage, Future)
buf = self.base.storage.get_buffer(self.dtype.itemsize)
result = np.frombuffer(buf, dtype=self.dtype, count=1)
return result.reshape(())
@@ -770,10 +774,11 @@ class DeferredArray(NumPyThunk):
store = self.base
rhs = self
+ computed_key: tuple[Any, ...]
if isinstance(key, NumPyThunk):
- key = (key,)
+ computed_key = (key,)
assert isinstance(key, tuple)
- key = self._unpack_ellipsis(key, self.ndim)
+ computed_key = self._unpack_ellipsis(key, self.ndim)
# the index where the first index_array is passed to the [] operator
start_index = -1
@@ -788,7 +793,7 @@ class DeferredArray(NumPyThunk):
tuple_of_arrays: tuple[Any, ...] = ()
# First, we need to check if transpose is needed
- for dim, k in enumerate(key):
+ for dim, k in enumerate(computed_key):
if np.isscalar(k) or isinstance(k, NumPyThunk):
if start_index == -1:
start_index = dim
@@ -813,17 +818,19 @@ class DeferredArray(NumPyThunk):
)
transpose_indices += post_indices
post_indices = tuple(
- i for i in range(len(key)) if i not in key_transpose_indices
+ i
+ for i in range(len(computed_key))
+ if i not in key_transpose_indices
)
key_transpose_indices += post_indices
store = store.transpose(transpose_indices)
- key = tuple(key[i] for i in key_transpose_indices)
+ key = tuple(computed_key[i] for i in key_transpose_indices)
shift = 0
- for dim, k in enumerate(key):
+ for dim, k in enumerate(computed_key):
if np.isscalar(k):
if k < 0: # type: ignore [operator]
- k += store.shape[dim + shift]
+ k += store.shape[dim + shift] # type: ignore [operator]
store = store.project(dim + shift, k)
shift -= 1
elif k is np.newaxis:
@@ -831,7 +838,7 @@ class DeferredArray(NumPyThunk):
elif isinstance(k, slice):
k, store = self._slice_store(k, store, dim + shift)
elif isinstance(k, NumPyThunk):
- if not isinstance(key, DeferredArray):
+ if not isinstance(computed_key, DeferredArray):
k = self.runtime.to_deferred_array(k)
if k.dtype == bool:
for i in range(k.ndim):
@@ -900,7 +907,7 @@ class DeferredArray(NumPyThunk):
k, store = self._slice_store(k, store, dim + shift)
elif np.isscalar(k):
if k < 0: # type: ignore [operator]
- k += store.shape[dim + shift]
+ k += store.shape[dim + shift] # type: ignore [operator]
store = store.project(dim + shift, k)
shift -= 1
else:
@@ -1329,10 +1336,9 @@ class DeferredArray(NumPyThunk):
dims = list(range(self.ndim))
dims[axis1], dims[axis2] = dims[axis2], dims[axis1]
- result = self.base.transpose(dims)
- result = DeferredArray(self.runtime, result)
+ result = self.base.transpose(tuple(dims))
- return result
+ return DeferredArray(self.runtime, result)
# Convert the source array to the destination array
@auto_convert("rhs")
@@ -1738,8 +1744,8 @@ class DeferredArray(NumPyThunk):
out_arr = self.base
# broadcast input array and all choices arrays to the same shape
- index = index_arr._broadcast(out_arr.shape)
- ch_tuple = tuple(c._broadcast(out_arr.shape) for c in ch_def)
+ index = index_arr._broadcast(out_arr.shape.extents)
+ ch_tuple = tuple(c._broadcast(out_arr.shape.extents) for c in ch_def)
task = self.context.create_auto_task(CuNumericOpCode.CHOOSE)
task.add_output(out_arr)
@@ -1985,9 +1991,9 @@ class DeferredArray(NumPyThunk):
def transpose(
self, axes: Union[None, tuple[int, ...], list[int]]
) -> DeferredArray:
- result = self.base.transpose(axes)
- result = DeferredArray(self.runtime, result)
- return result
+ computed_axes = tuple(axes) if axes is not None else ()
+ result = self.base.transpose(computed_axes)
+ return DeferredArray(self.runtime, result)
@auto_convert("rhs")
def trilu(self, rhs: Any, k: int, lower: bool) -> None:
diff --git a/cunumeric/eager.py b/cunumeric/eager.py
index 1e913739..56a056b6 100644
--- a/cunumeric/eager.py
+++ b/cunumeric/eager.py
@@ -235,7 +235,7 @@ class EagerArray(NumPyThunk):
self.escaped = False
@property
- def storage(self) -> Union[Future, tuple[Region, FieldID]]:
+ def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]:
if self.deferred is None:
self.to_deferred_array()
diff --git a/cunumeric/thunk.py b/cunumeric/thunk.py
index 0a831bce..68aafb6c 100644
--- a/cunumeric/thunk.py
+++ b/cunumeric/thunk.py
@@ -74,7 +74,7 @@ class NumPyThunk(ABC):
# Abstract methods
@abstractproperty
- def storage(self) -> Union[Future, tuple[Region, FieldID]]:
+ def storage(self) -> Union[Future, tuple[Region, Union[int, FieldID]]]:
"""Return the Legion storage primitive for this NumPy thunk"""
... |
@bryevdv : thank you for the fix! I am taking over the work on finishing np.select . |
I would suggest keeping the same branch for expedience. You should be able to push to it I think |
Co-authored-by: Manolis Papadakis <[email protected]>
Co-authored-by: Manolis Papadakis <[email protected]>
Co-authored-by: Manolis Papadakis <[email protected]>
Co-authored-by: Manolis Papadakis <[email protected]>
Co-authored-by: Manolis Papadakis <[email protected]>
Co-authored-by: Manolis Papadakis <[email protected]>
Co-authored-by: Manolis Papadakis <[email protected]>
/ok to test |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
OK, I think I tormented you enough. I think one file was miscommited, but otherwise this is ready to go.
Note that you'll need to approve the PR (because I opened it I can't approve it myself) before you can merge it.
src/cunumeric/index/select2.cu
Outdated
@@ -0,0 +1,122 @@ | |||
/* Copyright 2023 NVIDIA Corporation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Was this file added by mistake?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, sorry
Python bits are done and C++ stubs are in place, but need to be filled in. Probably a good idea to copy from
choose
operation. Passing to @ipdemes to take a look.@bryevdv I am seeing mypy errors on my local machine, that don't appear related to my changes. Possibly a
mypy
version bump is causing this, can you confirm that you're seeing them too, maybe suggest fixes?mypy errors