Skip to content

Commit

Permalink
debug: slicing on gpu
Browse files Browse the repository at this point in the history
  • Loading branch information
ianna committed Sep 16, 2024
1 parent 1c368f1 commit b18a029
Show file tree
Hide file tree
Showing 6 changed files with 804 additions and 669 deletions.
28 changes: 21 additions & 7 deletions src/awkward/_slicing.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,19 @@

SliceItem: TypeAlias = "int | slice | str | None | Ellipsis | ArrayLike | Content"

import functools

def trace_function_calls(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
print(f"_slicing.py: Calling function: {func.__name__}")
result = func(*args, **kwargs)
print(f"_slicing.py: Function {func.__name__} returned {result}")
return result
return wrapper


@trace_function_calls
def normalize_slice(slice_: slice, *, nplike: NumpyLike) -> slice:
"""
Args:
Expand Down Expand Up @@ -59,13 +71,15 @@ def __repr__(self):
S = TypeVar("S", bound=Sequence)


@trace_function_calls
def head_tail(sequence: S[T]) -> tuple[T | type(NO_HEAD), S[T]]:
if len(sequence) == 0:
return NO_HEAD, ()
else:
return sequence[0], sequence[1:]


@trace_function_calls
def prepare_advanced_indexing(items, backend: Backend):
"""Broadcast index objects to satisfy NumPy indexing rules
Expand Down Expand Up @@ -177,7 +191,7 @@ def prepare_advanced_indexing(items, backend: Backend):
)
return tuple(prepared)


@trace_function_calls
def normalize_integer_like(x) -> int | ArrayLike:
if is_array_like(x):
if np.issubdtype(x.dtype, np.integer) and x.ndim == 0:
Expand All @@ -187,7 +201,7 @@ def normalize_integer_like(x) -> int | ArrayLike:
else:
return int(x)


@trace_function_calls
def normalise_item(item, backend: Backend) -> SliceItem:
"""
Args:
Expand Down Expand Up @@ -300,12 +314,12 @@ def normalise_item(item, backend: Backend) -> SliceItem:
+ repr(item).replace("\n", "\n ")
)


@trace_function_calls
def normalise_items(where: Sequence, backend: Backend) -> list:
# First prepare items for broadcasting into like-types
return [normalise_item(x, backend=backend) for x in where]


@trace_function_calls
def _normalise_item_RegularArray_to_ListOffsetArray64(item: Content) -> Content:
if isinstance(item, ak.contents.RegularArray):
next = item.to_ListOffsetArray64()
Expand All @@ -321,7 +335,7 @@ def _normalise_item_RegularArray_to_ListOffsetArray64(item: Content) -> Content:
else:
raise AssertionError(type(item))


@trace_function_calls
def _normalise_item_nested(item: Content) -> Content:
if isinstance(item, ak.contents.EmptyArray):
# policy: unknown -> int
Expand Down Expand Up @@ -460,7 +474,7 @@ def _normalise_item_nested(item: Content) -> Content:
+ repr(item).replace("\n", "\n ")
)


@trace_function_calls
def _normalise_item_bool_to_int(item: Content, backend: Backend) -> Content:
"""
Args:
Expand Down Expand Up @@ -650,7 +664,7 @@ def _normalise_item_bool_to_int(item: Content, backend: Backend) -> Content:
else:
raise AssertionError(type(item))


@trace_function_calls
def getitem_next_array_wrap(
outcontent: Content, shape: tuple[int], outer_length: int = 0
) -> Content:
Expand Down
29 changes: 29 additions & 0 deletions src/awkward/contents/content.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,16 @@
JSONValueType: TypeAlias = """
float | int | str | list[JSONValueType] | dict[str, JSONValueType]
"""
import functools

def trace_function_calls(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
print(f"content.py: Calling function: {func.__name__}")
result = func(*args, **kwargs)
print(f"content.py: Function {func.__name__} returned {result}")
return result
return wrapper

class ImplementsApplyAction(Protocol):
def __call__(
Expand Down Expand Up @@ -298,6 +307,7 @@ def __iter__(self):
for i in range(len(self)):
yield self._getitem_at(i)

@trace_function_calls
def _getitem_next_field(
self,
head: SliceItem | tuple,
Expand All @@ -307,6 +317,7 @@ def _getitem_next_field(
nexthead, nexttail = ak._slicing.head_tail(tail)
return self._getitem_field(head)._getitem_next(nexthead, nexttail, advanced)

@trace_function_calls
def _getitem_next_fields(
self, head: SliceItem, tail: tuple[SliceItem, ...], advanced: Index | None
) -> Content:
Expand All @@ -321,6 +332,8 @@ def _getitem_next_fields(
nexthead, nexttail, advanced
)


@trace_function_calls
def _getitem_next_newaxis(
self, tail: tuple[SliceItem, ...], advanced: Index | None
) -> RegularArray:
Expand All @@ -329,6 +342,7 @@ def _getitem_next_newaxis(
self._getitem_next(nexthead, nexttail, advanced), 1, 0, parameters=None
)

@trace_function_calls
def _getitem_next_ellipsis(
self, tail: tuple[SliceItem, ...], advanced: Index | None
) -> Content:
Expand All @@ -353,6 +367,7 @@ def _getitem_next_ellipsis(
else:
return self._getitem_next(slice(None), (Ellipsis, *tail), advanced)

@trace_function_calls
def _getitem_next_regular_missing(
self,
head: IndexedOptionArray,
Expand Down Expand Up @@ -395,6 +410,7 @@ def _getitem_next_regular_missing(
out, indexlength, 1, parameters=self._parameters
)

@trace_function_calls
def _getitem_next_missing_jagged(
self, head: Content, tail, advanced: Index | None, that: Content
) -> RegularArray:
Expand Down Expand Up @@ -447,6 +463,7 @@ def _getitem_next_missing_jagged(
out, index.length, 1, parameters=self._parameters
)

@trace_function_calls
def _getitem_next_missing(
self,
head: IndexedOptionArray,
Expand Down Expand Up @@ -508,9 +525,11 @@ def _getitem_next_missing(
f"FIXME: unhandled case of SliceMissing with {nextcontent}"
)

@trace_function_calls
def __getitem__(self, where):
return self._getitem(where)

@trace_function_calls
def _getitem(self, where):
if is_integer_like(where):
return self._getitem_at(ak._slicing.normalize_integer_like(where))
Expand Down Expand Up @@ -693,25 +712,31 @@ def _getitem(self, where):
+ repr(where).replace("\n", "\n ")
)

@trace_function_calls
def _is_getitem_at_placeholder(self) -> bool:
raise NotImplementedError

@trace_function_calls
def _getitem_at(self, where: IndexType):
raise NotImplementedError

@trace_function_calls
def _getitem_range(self, start: IndexType, stop: IndexType) -> Content:
raise NotImplementedError

@trace_function_calls
def _getitem_field(
self, where: str | SupportsIndex, only_fields: tuple[str, ...] = ()
) -> Content:
raise NotImplementedError

@trace_function_calls
def _getitem_fields(
self, where: list[str], only_fields: tuple[str, ...] = ()
) -> Content:
raise NotImplementedError

@trace_function_calls
def _getitem_next(
self,
head: SliceItem | tuple,
Expand All @@ -720,6 +745,7 @@ def _getitem_next(
) -> Content:
raise NotImplementedError

@trace_function_calls
def _getitem_next_jagged(
self,
slicestarts: Index,
Expand All @@ -729,9 +755,11 @@ def _getitem_next_jagged(
) -> Content:
raise NotImplementedError

@trace_function_calls
def _carry(self, carry: Index, allow_lazy: bool) -> Content:
raise NotImplementedError

@trace_function_calls
def _local_index_axis0(self) -> NumpyArray:
localindex = Index64.empty(self.length, self._backend.index_nplike)
self._backend.maybe_kernel_error(
Expand All @@ -744,6 +772,7 @@ def _local_index_axis0(self) -> NumpyArray:
localindex.data, parameters=None, backend=self._backend
)

@trace_function_calls
def _mergeable_next(self, other: Content, mergebool: bool) -> bool:
raise NotImplementedError

Expand Down
23 changes: 23 additions & 0 deletions src/awkward/contents/listarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,16 @@

np = NumpyMetadata.instance()

import functools

def trace_function_calls(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
print(f"listarray.py: Calling function: {func.__name__}")
result = func(*args, **kwargs)
print(f"listarray.py: Function {func.__name__} returned {result}")
return result
return wrapper

@final
class ListArray(ListMeta[Content], Content):
Expand Down Expand Up @@ -298,14 +308,17 @@ def to_RegularArray(self):
offsets = self._compact_offsets64(True)
return self._broadcast_tooffsets64(offsets).to_RegularArray()

@trace_function_calls
def _getitem_nothing(self):
return self._content._getitem_range(0, 0)

@trace_function_calls
def _is_getitem_at_placeholder(self) -> bool:
return isinstance(self._starts, PlaceholderArray) or isinstance(
self._stops, PlaceholderArray
)

@trace_function_calls
def _getitem_at(self, where: IndexType):
if not self._backend.nplike.known_data:
self._touch_data(recursive=False)
Expand All @@ -318,6 +331,7 @@ def _getitem_at(self, where: IndexType):
start, stop = self._starts[where], self._stops[where]
return self._content._getitem_range(start, stop)

@trace_function_calls
def _getitem_range(self, start: IndexType, stop: IndexType) -> Content:
if not self._backend.nplike.known_data:
self._touch_shape(recursive=False)
Expand All @@ -330,6 +344,7 @@ def _getitem_range(self, start: IndexType, stop: IndexType) -> Content:
parameters=self._parameters,
)

@trace_function_calls
def _getitem_field(
self, where: str | SupportsIndex, only_fields: tuple[str, ...] = ()
) -> Content:
Expand All @@ -340,6 +355,7 @@ def _getitem_field(
parameters=None,
)

@trace_function_calls
def _getitem_fields(
self, where: list[str | SupportsIndex], only_fields: tuple[str, ...] = ()
) -> Content:
Expand All @@ -350,6 +366,7 @@ def _getitem_fields(
parameters=None,
)

@trace_function_calls
def _carry(self, carry: Index, allow_lazy: bool) -> Content:
assert isinstance(carry, ak.index.Index)

Expand All @@ -363,6 +380,7 @@ def _carry(self, carry: Index, allow_lazy: bool) -> Content:
nextstarts, nextstops, self._content, parameters=self._parameters
)

@trace_function_calls
def _compact_offsets64(self, start_at_zero):
starts_len = self._starts.length
out = ak.index.Index64.empty(
Expand All @@ -389,6 +407,7 @@ def _compact_offsets64(self, start_at_zero):
)
return out

@trace_function_calls
def _broadcast_tooffsets64(self, offsets: Index) -> ListOffsetArray:
self._touch_data(recursive=False)
offsets._touch_data()
Expand Down Expand Up @@ -443,6 +462,7 @@ def _broadcast_tooffsets64(self, offsets: Index) -> ListOffsetArray:

return ListOffsetArray(offsets, nextcontent, parameters=self._parameters)

@trace_function_calls
def _getitem_next_jagged(
self, slicestarts: Index, slicestops: Index, slicecontent: Content, tail
) -> Content:
Expand Down Expand Up @@ -698,6 +718,7 @@ def _getitem_next_jagged(
f"expected Index/IndexedOptionArray/ListOffsetArray in ListArray._getitem_next_jagged, got {type(slicecontent).__name__}"
)

@trace_function_calls
def _getitem_next(
self,
head: SliceItem | tuple,
Expand Down Expand Up @@ -1058,9 +1079,11 @@ def _getitem_next(
else:
raise AssertionError(repr(head))

@trace_function_calls
def _offsets_and_flattened(self, axis: int, depth: int) -> tuple[Index, Content]:
return self.to_ListOffsetArray64(True)._offsets_and_flattened(axis, depth)

@trace_function_calls
def _mergeable_next(self, other: Content, mergebool: bool) -> bool:
# Is the other content is an identity, or a union?
if other.is_identity_like or other.is_union:
Expand Down
Loading

0 comments on commit b18a029

Please sign in to comment.