From 5769509676f8e9d82a2d3bac3557a7d0d4a6b6d7 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Tue, 12 Sep 2023 13:29:38 -0700 Subject: [PATCH] Reject types that DeferredArray doesn't support IMHO there is little value in pretending that we support these types. We currently accept them by falling back to EagerArray, and that works for a little bit, but it is possible that later we will need to convert to a truly Legion-backed implementation, in which case the failure will pop up there. Instead we should just reject them from the start. --- cunumeric/array.py | 27 ++++++++++--------------- cunumeric/eager.py | 4 ++-- cunumeric/module.py | 4 +++- cunumeric/runtime.py | 7 +------ cunumeric/utils.py | 13 +++++++++--- tests/integration/test_array.py | 23 +++++++++++++++++++-- tests/integration/test_astype.py | 8 -------- tests/integration/test_atleast_nd.py | 15 -------------- tests/integration/test_fill.py | 1 + tests/integration/test_fill_diagonal.py | 2 +- tests/integration/test_logic.py | 2 +- tests/integration/test_matmul.py | 8 +++----- tests/integration/test_repeat.py | 16 --------------- tests/integration/test_sort.py | 18 ----------------- tests/integration/test_split.py | 22 -------------------- 15 files changed, 54 insertions(+), 116 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index b9db174e2..cce24e492 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -310,14 +310,9 @@ def __init__( if isinstance(inp, ndarray) ] core_dtype = to_core_dtype(dtype) - if core_dtype is not None: - self._thunk = runtime.create_empty_thunk( - sanitized_shape, core_dtype, inputs - ) - else: - self._thunk = runtime.create_eager_thunk( - sanitized_shape, dtype - ) + self._thunk = runtime.create_empty_thunk( + sanitized_shape, core_dtype, inputs + ) else: self._thunk = thunk self._legate_data: Union[dict[str, Any], None] = None @@ -1665,7 +1660,7 @@ def __rxor__(self, lhs: Any) -> ndarray: # __setattr__ @add_boilerplate("value") - def __setitem__(self, key: Any, value: Any) -> None: + def __setitem__(self, key: Any, value: ndarray) -> None: """__setitem__(key, value, /) Set ``self[key]=value``. @@ -2680,7 +2675,7 @@ def trace( return res @add_boilerplate("rhs") - def dot(self, rhs: Any, out: Union[ndarray, None] = None) -> ndarray: + def dot(self, rhs: ndarray, out: Union[ndarray, None] = None) -> ndarray: """a.dot(rhs, out=None) Return the dot product of this array with ``rhs``. @@ -4241,8 +4236,8 @@ def _perform_unary_reduction( def _perform_binary_reduction( cls, op: BinaryOpCode, - one: Any, - two: Any, + one: ndarray, + two: ndarray, dtype: np.dtype[Any], extra_args: Union[tuple[Any, ...], None] = None, ) -> ndarray: @@ -4258,14 +4253,14 @@ def _perform_binary_reduction( broadcast = None common_type = cls.find_common_type(one, two) - one = one._maybe_convert(common_type, args)._thunk - two = two._maybe_convert(common_type, args)._thunk + one_thunk = one._maybe_convert(common_type, args)._thunk + two_thunk = two._maybe_convert(common_type, args)._thunk dst = ndarray(shape=(), dtype=dtype, inputs=args) dst._thunk.binary_reduction( op, - one, - two, + one_thunk, + two_thunk, broadcast, extra_args, ) diff --git a/cunumeric/eager.py b/cunumeric/eager.py index bc83d49eb..680f1b5a1 100644 --- a/cunumeric/eager.py +++ b/cunumeric/eager.py @@ -42,7 +42,7 @@ ) from .deferred import DeferredArray from .thunk import NumPyThunk -from .utils import is_advanced_indexing +from .utils import is_advanced_indexing, is_supported_type if TYPE_CHECKING: import numpy.typing as npt @@ -305,7 +305,7 @@ def to_deferred_array(self) -> DeferredArray: # or whether we need to go up the tree to have it made if self.deferred is None: if self.parent is None: - assert self.runtime.is_supported_type(self.array.dtype) + assert is_supported_type(self.array.dtype) # We are at the root of the tree so we need to # actually make a DeferredArray to use if self.array.size == 1: diff --git a/cunumeric/module.py b/cunumeric/module.py index ad93b675e..c676dc02c 100644 --- a/cunumeric/module.py +++ b/cunumeric/module.py @@ -7136,7 +7136,9 @@ def bincount( # Handle the special case of 0-D array if weights is None: out = zeros((minlength,), dtype=np.dtype(np.int64)) - out[x[0]] = 1 + # TODO: Remove this "type: ignore" once @add_boilerplate can + # propagate "ndarray -> ndarray | npt.ArrayLike" in wrapped sigs + out[x[0]] = 1 # type: ignore [assignment] else: out = zeros((minlength,), dtype=weights.dtype) index = x[0] diff --git a/cunumeric/runtime.py b/cunumeric/runtime.py index 689a38423..a07c0847b 100644 --- a/cunumeric/runtime.py +++ b/cunumeric/runtime.py @@ -177,7 +177,6 @@ def create_wrapped_scalar( future = self.create_scalar(array, shape) assert all(extent == 1 for extent in shape) core_dtype = to_core_dtype(dtype) - assert core_dtype is not None store = self.legate_context.create_store( core_dtype, shape=shape, @@ -260,9 +259,6 @@ def get_next_random_epoch(self) -> int: self.current_random_epoch += 1 return result - def is_supported_type(self, dtype: Union[str, np.dtype[Any]]) -> bool: - return to_core_dtype(dtype) is not None - def get_numpy_thunk( self, obj: Union[ndarray, npt.NDArray[Any]], @@ -416,7 +412,7 @@ def find_or_create_array_thunk( # Check to see if it is a type that we support for doing deferred # execution and big enough to be worth off-loading onto Legion dtype = to_core_dtype(array.dtype) - if dtype is not None and ( + if ( defer or not self.is_eager_shape(array.shape) or self.has_external_attachment(array) @@ -446,7 +442,6 @@ def find_or_create_array_thunk( numpy_array=array if share else None, ) - assert not defer # Make this into an eager evaluated thunk return EagerArray(self, array) diff --git a/cunumeric/utils.py b/cunumeric/utils.py index 0586bb8f3..55a9b8c1e 100644 --- a/cunumeric/utils.py +++ b/cunumeric/utils.py @@ -18,7 +18,7 @@ from functools import reduce from string import ascii_lowercase, ascii_uppercase from types import FrameType -from typing import Any, Callable, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, List, Sequence, Tuple, Union import legate.core.types as ty import numpy as np @@ -43,8 +43,15 @@ } -def to_core_dtype(dtype: Union[str, np.dtype[Any]]) -> Optional[ty.Dtype]: - return SUPPORTED_DTYPES.get(np.dtype(dtype)) +def is_supported_type(dtype: Union[str, np.dtype[Any]]) -> bool: + return np.dtype(dtype) in SUPPORTED_DTYPES + + +def to_core_dtype(dtype: Union[str, np.dtype[Any]]) -> ty.Dtype: + core_dtype = SUPPORTED_DTYPES.get(np.dtype(dtype)) + if core_dtype is None: + raise TypeError(f"cuNumeric does not support dtype={dtype}") + return core_dtype def is_advanced_indexing(key: Any) -> bool: diff --git a/tests/integration/test_array.py b/tests/integration/test_array.py index 7adb8b093..6e03e98df 100755 --- a/tests/integration/test_array.py +++ b/tests/integration/test_array.py @@ -36,6 +36,13 @@ ), ) +UNSUPPORTED_OBJECTS = ( + None, + "somestr", + ["one", "two"], + [("name", "S10"), ("height", float), ("age", int)], +) + def strict_type_equal(a, b): return np.array_equal(a, b) and a.dtype == b.dtype @@ -43,7 +50,7 @@ def strict_type_equal(a, b): @pytest.mark.parametrize( "obj", - (None,) + SCALARS + ARRAYS, + SCALARS + ARRAYS, ids=lambda obj: f"(object={obj})", ) def test_array_basic(obj): @@ -52,6 +59,12 @@ def test_array_basic(obj): assert strict_type_equal(res_np, res_num) +@pytest.mark.parametrize("obj", UNSUPPORTED_OBJECTS) +def test_array_unsupported(obj): + with pytest.raises(TypeError, match="cuNumeric does not support dtype"): + num.array(obj) + + def test_array_ndarray(): obj = [[1, 2], [3, 4]] res_np = np.array(np.array(obj)) @@ -129,7 +142,7 @@ def test_invalid_dtype(self, obj, dtype): @pytest.mark.parametrize( "obj", - (None,) + SCALARS + ARRAYS, + SCALARS + ARRAYS, ids=lambda obj: f"(object={obj})", ) def test_asarray_basic(obj): @@ -138,6 +151,12 @@ def test_asarray_basic(obj): assert strict_type_equal(res_np, res_num) +@pytest.mark.parametrize("obj", UNSUPPORTED_OBJECTS) +def test_asarray_unsupported(obj): + with pytest.raises(TypeError, match="cuNumeric does not support dtype"): + num.array(obj) + + def test_asarray_ndarray(): obj = [[1, 2], [3, 4]] res_np = np.asarray(np.array(obj)) diff --git a/tests/integration/test_astype.py b/tests/integration/test_astype.py index 5a54a7789..725bab21b 100644 --- a/tests/integration/test_astype.py +++ b/tests/integration/test_astype.py @@ -46,14 +46,6 @@ def to_dtype(s): return str(np.dtype(s)) -def test_none(): - arr = None - in_np = num.array(arr) - msg = r"NoneType" - with pytest.raises(TypeError, match=msg): - in_np.astype("b") - - @pytest.mark.parametrize("src_dtype", ALL_TYPES, ids=to_dtype) def test_empty(src_dtype): arr = [] diff --git a/tests/integration/test_atleast_nd.py b/tests/integration/test_atleast_nd.py index 3946cb92f..da67e2de9 100644 --- a/tests/integration/test_atleast_nd.py +++ b/tests/integration/test_atleast_nd.py @@ -43,11 +43,6 @@ def test_atleast_1d_scalar(): assert np.array_equal(np.atleast_1d(a), num.atleast_1d(a)) -def test_atleast_1d_none(): - a = None - assert np.array_equal(np.atleast_1d(a), num.atleast_1d(a)) - - @pytest.mark.parametrize("size", SIZE_CASES, ids=str) def test_atleast_2d(size): a = [np.arange(np.prod(size)).reshape(size)] @@ -60,11 +55,6 @@ def test_atleast_2d_scalar(): assert np.array_equal(np.atleast_2d(a), num.atleast_2d(a)) -def test_atleast_2d_none(): - a = None - assert np.array_equal(np.atleast_2d(a), num.atleast_2d(a)) - - @pytest.mark.parametrize("size", SIZE_CASES, ids=str) def test_atleast_3d(size): a = [np.arange(np.prod(size)).reshape(size)] @@ -77,11 +67,6 @@ def test_atleast_3d_scalar(): assert np.array_equal(np.atleast_2d(a), num.atleast_2d(a)) -def test_atleast_3d_none(): - a = None - assert np.array_equal(np.atleast_2d(a), num.atleast_2d(a)) - - # test to run atleast_nd w/ list of arrays @pytest.mark.parametrize("dim", range(1, 4)) def test_atleast_nd(dim): diff --git a/tests/integration/test_fill.py b/tests/integration/test_fill.py index d77ae8e07..134e209f2 100644 --- a/tests/integration/test_fill.py +++ b/tests/integration/test_fill.py @@ -124,6 +124,7 @@ def test_fill_int_to_float() -> None: assert np.array_equal(a_np, a_num) +@pytest.mark.xfail def test_fill_string() -> None: a_list = ["hello", "hi"] a_np = np.array(a_list) diff --git a/tests/integration/test_fill_diagonal.py b/tests/integration/test_fill_diagonal.py index fc18783d3..f50309607 100644 --- a/tests/integration/test_fill_diagonal.py +++ b/tests/integration/test_fill_diagonal.py @@ -97,7 +97,7 @@ def test_dimension_mismatch(self): with pytest.raises(expected_exc): num.fill_diagonal(arr, 5) - @pytest.mark.parametrize("arr", (None, -3, [0], (5))) + @pytest.mark.parametrize("arr", (-3, [0], (5))) def test_arr_invalid(self, arr): arr_np = np.array(arr) arr_num = num.array(arr) diff --git a/tests/integration/test_logic.py b/tests/integration/test_logic.py index 4d15524b1..f969eb168 100644 --- a/tests/integration/test_logic.py +++ b/tests/integration/test_logic.py @@ -67,7 +67,7 @@ def test_out_invalid_shape(self, func_name): func_num(x, out=res_num) -SCALARS = (pytest.param("a string", marks=pytest.mark.xfail), None, False) +SCALARS = (pytest.param("a string", marks=pytest.mark.xfail), False) ARRAYS = ( [1.0, 2.0, 3.0], [1.0 + 0j, 2.0 + 0j, 3.0 + 0j], diff --git a/tests/integration/test_matmul.py b/tests/integration/test_matmul.py index 7a0759e0d..66f6ad89a 100644 --- a/tests/integration/test_matmul.py +++ b/tests/integration/test_matmul.py @@ -149,8 +149,7 @@ def test_out_invalid_shape_DIVERGENCE(self): @pytest.mark.parametrize( ("dtype", "out_dtype", "casting"), - ((None, np.int64, "same_kind"), (float, str, "safe")), - ids=("direct", "intermediate"), + ((None, np.int64, "same_kind"),), ) def test_out_invalid_dtype(self, dtype, out_dtype, casting): expected_exc = TypeError @@ -187,9 +186,8 @@ def test_invalid_casting_dtype(self, casting_dtype): with pytest.raises(expected_exc): num.matmul(A_num, B_num, casting=casting, dtype=dtype) - @pytest.mark.parametrize( - "dtype", (str, pytest.param(float, marks=pytest.mark.xfail)), ids=str - ) + @pytest.mark.xfail + @pytest.mark.parametrize("dtype", (float,), ids=str) def test_invalid_casting(self, dtype): expected_exc = ValueError casting = "unknown" diff --git a/tests/integration/test_repeat.py b/tests/integration/test_repeat.py index 1ea9eadef..3023f97c8 100644 --- a/tests/integration/test_repeat.py +++ b/tests/integration/test_repeat.py @@ -31,22 +31,6 @@ def test_repeats_none(array): num.repeat(array, None) -@pytest.mark.parametrize("repeats", (-3, [], [-3], [2, 3])) -def test_array_none_invalid(repeats): - expected_exc = ValueError - with pytest.raises(expected_exc): - np.repeat(None, repeats) - with pytest.raises(expected_exc): - num.repeat(None, repeats) - - -@pytest.mark.parametrize("repeats", (3, [0], [3], 4.7, [4.7])) -def test_array_none_valid(repeats): - res_num = num.repeat(None, repeats) - res_np = np.repeat(None, repeats) - assert np.array_equal(res_np, res_num) - - @pytest.mark.parametrize("repeats", (-3, 0, 3, 4.7, [], [-3], [0], [3], [4.7])) def test_array_empty_repeats_valid(repeats): res_np = np.repeat([], repeats) diff --git a/tests/integration/test_sort.py b/tests/integration/test_sort.py index 81e06d86a..1fdfc2f13 100644 --- a/tests/integration/test_sort.py +++ b/tests/integration/test_sort.py @@ -64,24 +64,6 @@ def test_arr_empty(self, arr): res_num = num.sort(arr) assert np.array_equal(res_num, res_np) - def test_structured_array_order(self): - dtype = [("name", "S10"), ("height", float), ("age", int)] - values = [ - ("Arthur", 1.8, 41), - ("Lancelot", 1.9, 38), - ("Galahad", 1.7, 38), - ] - a_np = np.array(values, dtype=dtype) - a_num = num.array(values, dtype=dtype) - - res_np = np.sort(a_np, order="height") - res_num = num.sort(a_num, order="height") - assert np.array_equal(res_np, res_num) - - res_np = np.sort(a_np, order=["age", "height"]) - res_num = num.sort(a_num, order=["age", "height"]) - assert np.array_equal(res_np, res_num) - def test_axis_out_bound(self): arr = [-1, 0, 1, 2, 10] with pytest.raises(ValueError): diff --git a/tests/integration/test_split.py b/tests/integration/test_split.py index d943f77b7..7646d478b 100644 --- a/tests/integration/test_split.py +++ b/tests/integration/test_split.py @@ -87,18 +87,6 @@ class TestSplitErrors: this class is to test negative cases """ - @pytest.mark.xfail - def test_array_none(self): - expected_exc = AttributeError - with pytest.raises(expected_exc): - np.split(None, 1) - # Numpy raises - # AttributeError: 'NoneType' object has no attribute 'shape' - with pytest.raises(expected_exc): - num.split(None, 1) - # cuNumeric raises - # ValueError: array(()) has less dimensions than axis(0) - @pytest.mark.parametrize("indices", (-2, 0, "hi", 1.0, None)) def test_indices_negative(self, indices): ary = num.arange(10) @@ -135,16 +123,6 @@ def test_axis_bigger(self): with pytest.raises(expected_exc): np.split(ary, 5, axis=axis) - @pytest.mark.parametrize("func_name", ARG_FUNCS) - def test_array_none_different_split(self, func_name): - expected_exc = ValueError - func_num = getattr(num, func_name) - func_np = getattr(np, func_name) - with pytest.raises(expected_exc): - func_np(None, 1) - with pytest.raises(expected_exc): - func_num(None, 1) - @pytest.mark.parametrize("indices", (-2, 0, "hi", 1.0, None)) @pytest.mark.parametrize("func_name", ARG_FUNCS) def test_indices_negative_different_split(self, func_name, indices):