From b5a4f958278027b6aa1ce4684732b2eaa084ae64 Mon Sep 17 00:00:00 2001 From: Manolis Papadakis Date: Fri, 12 Jan 2024 16:35:21 -0800 Subject: [PATCH] Fall back to inline mapping on single-element accesses Also fix the printing of fallback warnings --- cunumeric/array.py | 43 +++++++++++++++++++++++++++++++++++-------- cunumeric/coverage.py | 7 ++----- cunumeric/utils.py | 17 ++++++++++++----- 3 files changed, 49 insertions(+), 18 deletions(-) diff --git a/cunumeric/array.py b/cunumeric/array.py index 8bfc5178a..ad6158ab3 100644 --- a/cunumeric/array.py +++ b/cunumeric/array.py @@ -15,7 +15,6 @@ from __future__ import annotations import operator -import warnings from functools import reduce, wraps from inspect import signature from typing import ( @@ -83,6 +82,15 @@ P = ParamSpec("P") +_WARN_SINGLE_ELEM_ACCESS = ( + "cuNumeric detected a single-element access, and is proactively pulling " + "the entire array onto a single memory, to serve this and future " + "accesses. This may result in blocking and increased memory consumption. " + "Looping over ndarray elements is not efficient, please consider using " + "full-array operations instead." +) + + def add_boilerplate( *array_params: str, ) -> Callable[[Callable[P, R]], Callable[P, R]]: @@ -423,10 +431,9 @@ def __array_function__( what = f"the requested combination of arguments to {what}" # We cannot handle this call, so we will fall back to NumPy. - warnings.warn( + runtime.warn( FALLBACK_WARNING.format(what=what), category=RuntimeWarning, - stacklevel=4, ) args = deep_apply(args, maybe_convert_to_np_ndarray) kwargs = deep_apply(kwargs, maybe_convert_to_np_ndarray) @@ -469,10 +476,9 @@ def __array_ufunc__( what = f"the requested combination of arguments to {what}" # We cannot handle this ufunc call, so we will fall back to NumPy. - warnings.warn( + runtime.warn( FALLBACK_WARNING.format(what=what), category=RuntimeWarning, - stacklevel=3, ) inputs = deep_apply(inputs, maybe_convert_to_np_ndarray) kwargs = deep_apply(kwargs, maybe_convert_to_np_ndarray) @@ -1027,6 +1033,14 @@ def _convert_key(self, key: Any, first: bool = True) -> Any: return key._thunk + def _is_single_elem_access(self, key: Any) -> bool: + # Just do a quick check to catch literal uses of scalar indices + return ( + isinstance(key, tuple) + and len(key) == self.ndim + and all(np.isscalar(k) for k in key) + ) + @add_boilerplate() def __getitem__(self, key: Any) -> ndarray: """a.__getitem__(key, /) @@ -1035,6 +1049,12 @@ def __getitem__(self, key: Any) -> ndarray: """ key = self._convert_key(key) + if self._is_single_elem_access(key): + runtime.warn( + _WARN_SINGLE_ELEM_ACCESS, + category=RuntimeWarning, + ) + return self.__array__()[key] return ndarray(shape=None, thunk=self._thunk.get_item(key)) def __gt__(self, rhs: Any) -> ndarray: @@ -1664,19 +1684,26 @@ def __rxor__(self, lhs: Any) -> ndarray: return bitwise_xor(lhs, self) # __setattr__ - @add_boilerplate("value") - def __setitem__(self, key: Any, value: ndarray) -> None: + def __setitem__(self, key: Any, raw_value: Any) -> None: """__setitem__(key, value, /) Set ``self[key]=value``. """ check_writeable(self) + key = self._convert_key(key) + if self._is_single_elem_access(key): + runtime.warn( + _WARN_SINGLE_ELEM_ACCESS, + category=RuntimeWarning, + ) + self.__array__()[key] = raw_value + return + value = convert_to_cunumeric_ndarray(raw_value) if value.dtype != self.dtype: temp = ndarray(value.shape, dtype=self.dtype, inputs=(value,)) temp._thunk.convert(value._thunk) value = temp - key = self._convert_key(key) self._thunk.set_item(key, value._thunk) def __setstate__(self, state: Any) -> None: diff --git a/cunumeric/coverage.py b/cunumeric/coverage.py index a8e57285f..c621129c2 100644 --- a/cunumeric/coverage.py +++ b/cunumeric/coverage.py @@ -14,7 +14,6 @@ # from __future__ import annotations -import warnings from dataclasses import dataclass from functools import WRAPPER_ASSIGNMENTS, wraps from types import ( @@ -41,7 +40,7 @@ from .runtime import runtime from .settings import settings -from .utils import deep_apply, find_last_user_frames, find_last_user_stacklevel +from .utils import deep_apply, find_last_user_frames __all__ = ("clone_module", "clone_class") @@ -182,10 +181,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any: @wraps(func, assigned=_UNIMPLEMENTED_COPIED_ATTRS) def wrapper(*args: Any, **kwargs: Any) -> Any: - stacklevel = find_last_user_stacklevel() - warnings.warn( + runtime.warn( FALLBACK_WARNING.format(what=name), - stacklevel=stacklevel, category=RuntimeWarning, ) if fallback: diff --git a/cunumeric/utils.py b/cunumeric/utils.py index 8c2d70140..ade176411 100644 --- a/cunumeric/utils.py +++ b/cunumeric/utils.py @@ -68,10 +68,20 @@ def is_advanced_indexing(key: Any) -> bool: return True +_INTERNAL_MODULE_PREFIXES = ("cunumeric.", "legate.core.") + + +def is_internal_frame(frame: FrameType) -> bool: + if "__name__" not in frame.f_globals: + return False + name = frame.f_globals["__name__"] + return any(name.startswith(prefix) for prefix in _INTERNAL_MODULE_PREFIXES) + + def find_last_user_stacklevel() -> int: stacklevel = 1 for frame, _ in traceback.walk_stack(None): - if not frame.f_globals["__name__"].startswith("cunumeric"): + if not is_internal_frame(frame): break stacklevel += 1 return stacklevel @@ -83,10 +93,7 @@ def get_line_number_from_frame(frame: FrameType) -> str: def find_last_user_frames(top_only: bool = True) -> str: for last, _ in traceback.walk_stack(None): - if "__name__" not in last.f_globals: - continue - name = last.f_globals["__name__"] - if not any(name.startswith(pkg) for pkg in ("cunumeric", "legate")): + if not is_internal_frame(last): break if top_only: