Skip to content

Commit

Permalink
Fall back to inline mapping on single-element accesses
Browse files Browse the repository at this point in the history
Also fix the printing of fallback warnings
  • Loading branch information
manopapad committed Jan 13, 2024
1 parent c1bfd9d commit b5a4f95
Show file tree
Hide file tree
Showing 3 changed files with 49 additions and 18 deletions.
43 changes: 35 additions & 8 deletions cunumeric/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
from __future__ import annotations

import operator
import warnings
from functools import reduce, wraps
from inspect import signature
from typing import (
Expand Down Expand Up @@ -83,6 +82,15 @@
P = ParamSpec("P")


_WARN_SINGLE_ELEM_ACCESS = (
"cuNumeric detected a single-element access, and is proactively pulling "
"the entire array onto a single memory, to serve this and future "
"accesses. This may result in blocking and increased memory consumption. "
"Looping over ndarray elements is not efficient, please consider using "
"full-array operations instead."
)


def add_boilerplate(
*array_params: str,
) -> Callable[[Callable[P, R]], Callable[P, R]]:
Expand Down Expand Up @@ -423,10 +431,9 @@ def __array_function__(
what = f"the requested combination of arguments to {what}"

# We cannot handle this call, so we will fall back to NumPy.
warnings.warn(
runtime.warn(
FALLBACK_WARNING.format(what=what),
category=RuntimeWarning,
stacklevel=4,
)
args = deep_apply(args, maybe_convert_to_np_ndarray)
kwargs = deep_apply(kwargs, maybe_convert_to_np_ndarray)
Expand Down Expand Up @@ -469,10 +476,9 @@ def __array_ufunc__(
what = f"the requested combination of arguments to {what}"

# We cannot handle this ufunc call, so we will fall back to NumPy.
warnings.warn(
runtime.warn(
FALLBACK_WARNING.format(what=what),
category=RuntimeWarning,
stacklevel=3,
)
inputs = deep_apply(inputs, maybe_convert_to_np_ndarray)
kwargs = deep_apply(kwargs, maybe_convert_to_np_ndarray)
Expand Down Expand Up @@ -1027,6 +1033,14 @@ def _convert_key(self, key: Any, first: bool = True) -> Any:

return key._thunk

def _is_single_elem_access(self, key: Any) -> bool:
# Just do a quick check to catch literal uses of scalar indices
return (
isinstance(key, tuple)
and len(key) == self.ndim
and all(np.isscalar(k) for k in key)
)

@add_boilerplate()
def __getitem__(self, key: Any) -> ndarray:
"""a.__getitem__(key, /)
Expand All @@ -1035,6 +1049,12 @@ def __getitem__(self, key: Any) -> ndarray:
"""
key = self._convert_key(key)
if self._is_single_elem_access(key):
runtime.warn(
_WARN_SINGLE_ELEM_ACCESS,
category=RuntimeWarning,
)
return self.__array__()[key]
return ndarray(shape=None, thunk=self._thunk.get_item(key))

def __gt__(self, rhs: Any) -> ndarray:
Expand Down Expand Up @@ -1664,19 +1684,26 @@ def __rxor__(self, lhs: Any) -> ndarray:
return bitwise_xor(lhs, self)

# __setattr__
@add_boilerplate("value")
def __setitem__(self, key: Any, value: ndarray) -> None:
def __setitem__(self, key: Any, raw_value: Any) -> None:
"""__setitem__(key, value, /)
Set ``self[key]=value``.
"""
check_writeable(self)
key = self._convert_key(key)
if self._is_single_elem_access(key):
runtime.warn(
_WARN_SINGLE_ELEM_ACCESS,
category=RuntimeWarning,
)
self.__array__()[key] = raw_value
return
value = convert_to_cunumeric_ndarray(raw_value)
if value.dtype != self.dtype:
temp = ndarray(value.shape, dtype=self.dtype, inputs=(value,))
temp._thunk.convert(value._thunk)
value = temp
key = self._convert_key(key)
self._thunk.set_item(key, value._thunk)

def __setstate__(self, state: Any) -> None:
Expand Down
7 changes: 2 additions & 5 deletions cunumeric/coverage.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
#
from __future__ import annotations

import warnings
from dataclasses import dataclass
from functools import WRAPPER_ASSIGNMENTS, wraps
from types import (
Expand All @@ -41,7 +40,7 @@

from .runtime import runtime
from .settings import settings
from .utils import deep_apply, find_last_user_frames, find_last_user_stacklevel
from .utils import deep_apply, find_last_user_frames

__all__ = ("clone_module", "clone_class")

Expand Down Expand Up @@ -182,10 +181,8 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:

@wraps(func, assigned=_UNIMPLEMENTED_COPIED_ATTRS)
def wrapper(*args: Any, **kwargs: Any) -> Any:
stacklevel = find_last_user_stacklevel()
warnings.warn(
runtime.warn(
FALLBACK_WARNING.format(what=name),
stacklevel=stacklevel,
category=RuntimeWarning,
)
if fallback:
Expand Down
17 changes: 12 additions & 5 deletions cunumeric/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,20 @@ def is_advanced_indexing(key: Any) -> bool:
return True


_INTERNAL_MODULE_PREFIXES = ("cunumeric.", "legate.core.")


def is_internal_frame(frame: FrameType) -> bool:
if "__name__" not in frame.f_globals:
return False
name = frame.f_globals["__name__"]
return any(name.startswith(prefix) for prefix in _INTERNAL_MODULE_PREFIXES)


def find_last_user_stacklevel() -> int:
stacklevel = 1
for frame, _ in traceback.walk_stack(None):
if not frame.f_globals["__name__"].startswith("cunumeric"):
if not is_internal_frame(frame):
break
stacklevel += 1
return stacklevel
Expand All @@ -83,10 +93,7 @@ def get_line_number_from_frame(frame: FrameType) -> str:

def find_last_user_frames(top_only: bool = True) -> str:
for last, _ in traceback.walk_stack(None):
if "__name__" not in last.f_globals:
continue
name = last.f_globals["__name__"]
if not any(name.startswith(pkg) for pkg in ("cunumeric", "legate")):
if not is_internal_frame(last):
break

if top_only:
Expand Down

0 comments on commit b5a4f95

Please sign in to comment.