Skip to content

Commit

Permalink
fix: replace axis_wrap_if_negative with maybe_posaxis, simpler an…
Browse files Browse the repository at this point in the history
…d more correct (#1986)

Co-authored-by: Angus Hollands <[email protected]>
  • Loading branch information
jpivarski and agoose77 authored Dec 9, 2022
1 parent fed8e3e commit d4fa292
Show file tree
Hide file tree
Showing 31 changed files with 496 additions and 634 deletions.
52 changes: 7 additions & 45 deletions src/awkward/_do.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import awkward as ak
from awkward._backends import Backend
from awkward.contents.content import ActionType, AxisMaybeNone, Content
from awkward.contents.content import ActionType, Content
from awkward.forms import form
from awkward.record import Record
from awkward.typing import Any
Expand Down Expand Up @@ -122,46 +122,8 @@ def getkey(layout, form, attribute):
return form, len(content), container


def axis_wrap_if_negative(
layout: Content | Record, axis: AxisMaybeNone
) -> AxisMaybeNone:
if isinstance(layout, Record):
if axis == 0:
raise ak._errors.wrap_error(
np.AxisError("Record type at axis=0 is a scalar, not an array")
)
return axis_wrap_if_negative(layout._array, axis)

else:
if axis is None or axis >= 0:
return axis

mindepth, maxdepth = layout.minmax_depth
depth = layout.purelist_depth
if mindepth == depth and maxdepth == depth:
posaxis = depth + axis
if posaxis < 0:
raise ak._errors.wrap_error(
np.AxisError(
f"axis={axis} exceeds the depth ({depth}) of this array"
)
)
return posaxis

elif mindepth + axis == 0:
raise ak._errors.wrap_error(
np.AxisError(
"axis={} exceeds the depth ({}) of at least one record field (or union possibility) of this array".format(
axis, depth
)
)
)

return axis


def local_index(layout: Content, axis: Integral):
return layout._local_index(axis, 0)
return layout._local_index(axis, 1)


def combinations(
Expand All @@ -184,7 +146,7 @@ def combinations(
raise ak._errors.wrap_error(
ValueError("if provided, the length of 'fields' must be 'n'")
)
return layout._combinations(n, replacement, recordlookup, parameters, axis, 0)
return layout._combinations(n, replacement, recordlookup, parameters, axis, 1)


def is_unique(layout, axis: Integral | None = None) -> bool:
Expand Down Expand Up @@ -248,7 +210,7 @@ def unique(layout: Content, axis=None):
def pad_none(
layout: Content, length: Integral, axis: Integral, clip: bool = False
) -> Content:
return layout._pad_none(length, axis, 0, clip)
return layout._pad_none(length, axis, 1, clip)


def completely_flatten(
Expand Down Expand Up @@ -281,8 +243,8 @@ def completely_flatten(
return tuple(arrays)


def flatten(layout: Content, axis: Integral = 1, depth: Integral = 0) -> Content:
offsets, flattened = layout._offsets_and_flattened(axis, depth)
def flatten(layout: Content, axis: Integral = 1) -> Content:
offsets, flattened = layout._offsets_and_flattened(axis, 1)
return flattened


Expand All @@ -295,7 +257,7 @@ def fill_none(layout: Content, value: Content) -> Content:


def num(layout, axis):
return layout._num(axis)
return layout._num(axis, 0)


def mergeable(one: Content, two: Content, mergebool: bool = True) -> bool:
Expand Down
19 changes: 19 additions & 0 deletions src/awkward/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,3 +825,22 @@ def _impl(array):
raise ak._errors.wrap_error(
ValueError(f"{module.__name__} is not supported by to_arraylib")
)


def maybe_posaxis(layout, axis, depth):
if isinstance(layout, ak.record.Record):
if axis == 0:
raise ak._errors.wrap_error(
np.AxisError("Record type at axis=0 is a scalar, not an array")
)
return maybe_posaxis(layout._array, axis, depth)

if axis >= 0:
return axis

else:
is_branching, additional_depth = layout.branch_depth
if not is_branching:
return axis + depth + additional_depth - 1
else:
return None
3 changes: 0 additions & 3 deletions src/awkward/contents/bitmaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -434,9 +434,6 @@ def _getitem_next(self, head, tail, advanced):
def project(self, mask=None):
return self.to_ByteMaskedArray().project(mask)

def _num(self, axis, depth=0):
return self.to_ByteMaskedArray()._num(axis, depth)

def _offsets_and_flattened(self, axis, depth):
return self.to_ByteMaskedArray._offsets_and_flattened(axis, depth)

Expand Down
46 changes: 14 additions & 32 deletions src/awkward/contents/bytemaskedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,34 +556,16 @@ def project(self, mask=None):

return self._content._carry(nextcarry, False)

def _num(self, axis, depth=0):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
out = self.length
if ak._util.is_integer(out):
return np.int64(out)
else:
return out
else:
_, nextcarry, outindex = self._nextcarry_outindex(self._backend)

next = self._content._carry(nextcarry, False)
out = next._num(posaxis, depth)

return ak.contents.IndexedOptionArray.simplified(
outindex, out, parameters=self._parameters
)

def _offsets_and_flattened(self, axis, depth):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
raise ak._errors.wrap_error(np.AxisError("axis=0 not allowed for flatten"))
else:
numnull, nextcarry, outindex = self._nextcarry_outindex(self._backend)

next = self._content._carry(nextcarry, False)

offsets, flattened = next._offsets_and_flattened(posaxis, depth)
offsets, flattened = next._offsets_and_flattened(axis, depth)

if offsets.length == 0:
return (
Expand Down Expand Up @@ -672,14 +654,14 @@ def _fill_none(self, value: Content) -> Content:
return self.to_IndexedOptionArray64()._fill_none(value)

def _local_index(self, axis, depth):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
return self._local_index_axis0()
else:
_, nextcarry, outindex = self._nextcarry_outindex(self._backend)

next = self._content._carry(nextcarry, False)
out = next._local_index(posaxis, depth)
out = next._local_index(axis, depth)
return ak.contents.IndexedOptionArray.simplified(
outindex, out, parameters=self._parameters
)
Expand Down Expand Up @@ -749,15 +731,15 @@ def _combinations(self, n, replacement, recordlookup, parameters, axis, depth):
raise ak._errors.wrap_error(
ValueError("in combinations, 'n' must be at least 1")
)
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
return self._combinations_axis0(n, replacement, recordlookup, parameters)
else:
_, nextcarry, outindex = self._nextcarry_outindex(self._backend)

next = self._content._carry(nextcarry, True)
out = next._combinations(
n, replacement, recordlookup, parameters, posaxis, depth
n, replacement, recordlookup, parameters, axis, depth
)
return ak.contents.IndexedOptionArray.simplified(
outindex, out, parameters=parameters
Expand Down Expand Up @@ -938,10 +920,10 @@ def _nbytes_part(self):
return self.mask._nbytes_part() + self.content._nbytes_part()

def _pad_none(self, target, axis, depth, clip):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
return self._pad_none_axis0(target, clip)
elif posaxis == depth + 1:
elif posaxis is not None and posaxis + 1 == depth + 1:
mask = ak.index.Index8(self.mask_as_bool(valid_when=False))
index = ak.index.Index64.empty(
mask.length, nplike=self._backend.index_nplike
Expand All @@ -961,14 +943,14 @@ def _pad_none(self, target, axis, depth, clip):
self._mask.length,
)
)
next = self.project()._pad_none(target, posaxis, depth, clip)
next = self.project()._pad_none(target, axis, depth, clip)
return ak.contents.IndexedOptionArray.simplified(
index, next, parameters=self._parameters
)
else:
return ak.contents.ByteMaskedArray(
self._mask,
self._content._pad_none(target, posaxis, depth, clip),
self._content._pad_none(target, axis, depth, clip),
self._valid_when,
parameters=self._parameters,
)
Expand Down
38 changes: 15 additions & 23 deletions src/awkward/contents/emptyarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -158,23 +158,9 @@ def _getitem_next(self, head, tail, advanced):
else:
raise ak._errors.wrap_error(AssertionError(repr(head)))

def _num(self, axis, depth=0):
posaxis = ak._do.axis_wrap_if_negative(self, axis)

if posaxis == depth:
out = self.length
if ak._util.is_integer(out):
return np.int64(out)
else:
return out
else:
# TODO: This was changed to use `nplike` instead of `index_nplike`. Is this OK?
out = ak.index.Index64.empty(0, nplike=self._backend.nplike)
return ak.contents.NumpyArray(out, parameters=None, backend=self._backend)

def _offsets_and_flattened(self, axis, depth):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
raise ak._errors.wrap_error(
np.AxisError(self, "axis=0 not allowed for flatten")
)
Expand Down Expand Up @@ -203,11 +189,17 @@ def _fill_none(self, value: Content) -> Content:
return EmptyArray(parameters=self._parameters, backend=self._backend)

def _local_index(self, axis, depth):
return ak.contents.NumpyArray(
self._backend.nplike.empty(0, np.int64),
parameters=None,
backend=self._backend,
)
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
return ak.contents.NumpyArray(
self._backend.nplike.empty(0, np.int64),
parameters=None,
backend=self._backend,
)
else:
raise ak._errors.wrap_error(
np.AxisError(f"axis={axis} exceeds the depth of this array ({depth})")
)

def _numbers_to_type(self, name):
return ak.contents.EmptyArray(
Expand Down Expand Up @@ -287,8 +279,8 @@ def _nbytes_part(self):
return 0

def _pad_none(self, target, axis, depth, clip):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis != depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 != depth:
raise ak._errors.wrap_error(
np.AxisError(f"axis={axis} exceeds the depth of this array ({depth})")
)
Expand Down
39 changes: 14 additions & 25 deletions src/awkward/contents/indexedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,24 +388,13 @@ def project(self, mask=None):
)
)

def _num(self, axis, depth=0):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
out = self.length
if ak._util.is_integer(out):
return np.int64(out)
else:
return out
else:
return self.project()._num(posaxis, depth)

def _offsets_and_flattened(self, axis, depth):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
raise ak._errors.wrap_error(np.AxisError("axis=0 not allowed for flatten"))

else:
return self.project()._offsets_and_flattened(posaxis, depth)
return self.project()._offsets_and_flattened(axis, depth)

def _mergeable_next(self, other, mergebool):
if isinstance(
Expand Down Expand Up @@ -602,11 +591,11 @@ def _fill_none(self, value: Content) -> Content:
)

def _local_index(self, axis, depth):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
return self._local_index_axis0()
else:
return self.project()._local_index(posaxis, depth)
return self.project()._local_index(axis, depth)

def _unique_index(self, index, sorted=True):
next = ak.index.Index64.zeros(self.length, nplike=self._backend.index_nplike)
Expand Down Expand Up @@ -871,12 +860,12 @@ def _sort_next(
)

def _combinations(self, n, replacement, recordlookup, parameters, axis, depth):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis)
if posaxis is not None and posaxis + 1 == depth:
return self._combinations_axis0(n, replacement, recordlookup, parameters)
else:
return self.project()._combinations(
n, replacement, recordlookup, parameters, posaxis, depth
n, replacement, recordlookup, parameters, axis, depth
)

def _reduce_next(
Expand Down Expand Up @@ -927,15 +916,15 @@ def _nbytes_part(self):
return self.index._nbytes_part() + self.content._nbytes_part()

def _pad_none(self, target, axis, depth, clip):
posaxis = ak._do.axis_wrap_if_negative(self, axis)
if posaxis == depth:
posaxis = ak._util.maybe_posaxis(self, axis, depth)
if posaxis is not None and posaxis + 1 == depth:
return self._pad_none_axis0(target, clip)
elif posaxis == depth + 1:
return self.project()._pad_none(target, posaxis, depth, clip)
elif posaxis is not None and posaxis + 1 == depth + 1:
return self.project()._pad_none(target, axis, depth, clip)
else:
return ak.contents.IndexedArray(
self._index,
self._content._pad_none(target, posaxis, depth, clip),
self._content._pad_none(target, axis, depth, clip),
parameters=self._parameters,
)

Expand Down
Loading

0 comments on commit d4fa292

Please sign in to comment.