Skip to content

Commit

Permalink
fix: 2881 akconcatenate fails trying to concatenate too many nested a…
Browse files Browse the repository at this point in the history
…rrays (#3207)

To resolve this we allow `UnionArray.nested_tags_index` and `UnionArray.simplified` to
handle incoming tags with type Index64. Resulting, simplified UnionArrays are
still limited to 128 distinct types/tags, but this limitation is much less likely to
create limitations in practice.

* New approach: add partial support for 64-bit tags

Expanding the template instances for:
* awkward_UnionArray_nestedfill_tags_index
* awkward_UnionArray_simplify_one
* awkward_UnionArray_simplify
These allow 64-bit tags in certain cases, for ak_concatenate.

* Ordering specializations in kernel-specification.yml

* UnionArray.simplify length tests moved.

Also a few style modifications.
The reason for moving those length checks
is so that I do not wish even temporarily hold on
to content tags with invalid data.

* Reverting change to batch index in ak_concatenate

Co-authored-by: Jim Pivarski <[email protected]>

---------

Co-authored-by: Jim Pivarski <[email protected]>
  • Loading branch information
tcawlfield and jpivarski authored Aug 8, 2024
1 parent 3ab54f8 commit af232f5
Show file tree
Hide file tree
Showing 7 changed files with 153 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,13 @@ ERROR awkward_UnionArray8_64_nestedfill_tags_index_64(
return awkward_UnionArray_nestedfill_tags_index<int8_t, int64_t, int64_t>(
totags, toindex, tmpstarts, tag, fromcounts, length);
}
ERROR awkward_UnionArray64_64_nestedfill_tags_index_64(
int64_t* totags,
int64_t* toindex,
int64_t* tmpstarts,
int64_t tag,
const int64_t* fromcounts,
int64_t length) {
return awkward_UnionArray_nestedfill_tags_index<int64_t, int64_t, int64_t>(
totags, toindex, tmpstarts, tag, fromcounts, length);
}
25 changes: 25 additions & 0 deletions awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -258,3 +258,28 @@ ERROR awkward_UnionArray8_64_simplify8_64_to8_64(
length,
base);
}
ERROR awkward_UnionArray64_64_simplify8_64_to8_64(
int8_t* totags,
int64_t* toindex,
const int64_t* outertags,
const int64_t* outerindex,
const int8_t* innertags,
const int64_t* innerindex,
int64_t towhich,
int64_t innerwhich,
int64_t outerwhich,
int64_t length,
int64_t base) {
return awkward_UnionArray_simplify<int64_t, int64_t, int8_t, int64_t, int8_t, int64_t>(
totags,
toindex,
outertags,
outerindex,
innertags,
innerindex,
towhich,
innerwhich,
outerwhich,
length,
base);
}
19 changes: 19 additions & 0 deletions awkward-cpp/src/cpu-kernels/awkward_UnionArray_simplify_one.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -82,3 +82,22 @@ ERROR awkward_UnionArray8_64_simplify_one_to8_64(
length,
base);
}
ERROR awkward_UnionArray64_64_simplify_one_to8_64(
int8_t* totags,
int64_t* toindex,
const int64_t* fromtags,
const int64_t* fromindex,
int64_t towhich,
int64_t fromwhich,
int64_t length,
int64_t base) {
return awkward_UnionArray_simplify_one<int64_t, int64_t, int8_t, int64_t>(
totags,
toindex,
fromtags,
fromindex,
towhich,
fromwhich,
length,
base);
}
31 changes: 31 additions & 0 deletions kernel-specification.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3372,6 +3372,14 @@ kernels:

- name: awkward_UnionArray_nestedfill_tags_index
specializations:
- name: awkward_UnionArray64_64_nestedfill_tags_index_64
args:
- {name: totags, type: "List[int64_t]", dir: out}
- {name: toindex, type: "List[int64_t]", dir: out}
- {name: tmpstarts, type: "List[int64_t]", dir: out}
- {name: tag, type: "int64_t", dir: in, role: default}
- {name: fromcounts, type: "Const[List[int64_t]]", dir: in, role: default}
- {name: length, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray8_32_nestedfill_tags_index_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
Expand Down Expand Up @@ -3503,6 +3511,19 @@ kernels:

- name: awkward_UnionArray_simplify
specializations:
- name: awkward_UnionArray64_64_simplify8_64_to8_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
- {name: toindex, type: "List[int64_t]", dir: out}
- {name: outertags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags}
- {name: outerindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray-index}
- {name: innertags, type: "Const[List[int8_t]]", dir: in, role: UnionArray2-tags}
- {name: innerindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray2-index}
- {name: towhich, type: "int64_t", dir: in, role: default}
- {name: innerwhich, type: "int64_t", dir: in, role: UnionArray1-which}
- {name: outerwhich, type: "int64_t", dir: in, role: UnionArray2-which}
- {name: length, type: "int64_t", dir: in, role: default}
- {name: base, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray8_32_simplify8_32_to8_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
Expand Down Expand Up @@ -3645,6 +3666,16 @@ kernels:

- name: awkward_UnionArray_simplify_one
specializations:
- name: awkward_UnionArray64_64_simplify_one_to8_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
- {name: toindex, type: "List[int64_t]", dir: out}
- {name: fromtags, type: "Const[List[int64_t]]", dir: in, role: UnionArray-tags}
- {name: fromindex, type: "Const[List[int64_t]]", dir: in, role: IndexedArray-index}
- {name: towhich, type: "int64_t", dir: in, role: default}
- {name: fromwhich, type: "int64_t", dir: in, role: UnionArray-which}
- {name: length, type: "int64_t", dir: in, role: default}
- {name: base, type: "int64_t", dir: in, role: default}
- name: awkward_UnionArray8_32_simplify_one_to8_64
args:
- {name: totags, type: "List[int8_t]", dir: out}
Expand Down
26 changes: 17 additions & 9 deletions src/awkward/contents/unionarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@

np = NumpyMetadata.instance()
numpy = Numpy.instance()
MAX_UNION_CONTENTS = 2**7 # We use int8 tags, 0-127


@final
Expand Down Expand Up @@ -230,6 +231,10 @@ def simplified(
parameters=None,
mergebool=False,
):
# Note: to help merge more than 128 arrays, tags *can* have type ak.index.Index64.
# This is only supported when index is also Index64,
# and all indexed contents are also Index64.
# We still require that this reduces to no more than 128 variants.
self_index = index
self_tags = tags
self_contents = contents
Expand Down Expand Up @@ -299,6 +304,10 @@ def simplified(

# Did we fail to merge any of the final outer contents with this inner union content?
if unmerged:
if len(contents) >= MAX_UNION_CONTENTS:
raise ValueError(
"UnionArray does not support more than 128 content types"
)
backend.maybe_kernel_error(
backend[
"awkward_UnionArray_simplify",
Expand Down Expand Up @@ -373,6 +382,10 @@ def simplified(
break

if unmerged:
if len(contents) >= MAX_UNION_CONTENTS:
raise ValueError(
"UnionArray does not support more than 128 content types"
)
backend.maybe_kernel_error(
backend[
"awkward_UnionArray_simplify_one",
Expand All @@ -393,11 +406,6 @@ def simplified(
)
contents.append(self_cont)

if len(contents) > 2**7:
raise NotImplementedError(
"FIXME: handle UnionArray with more than 127 contents"
)

# If any contents are non-categorical index types, we can merge them into the union
# This is safe, because any remaining index types at this point in the routine are not considered
# mergeable with the other contents. This means none of the other contents are option or index types,
Expand Down Expand Up @@ -1107,8 +1115,8 @@ def _reverse_merge(self, other):
)
)

if len(contents) > 2**7:
raise AssertionError("FIXME: handle UnionArray with more than 127 contents")
if len(contents) > MAX_UNION_CONTENTS:
raise ValueError("UnionArray cannot have more than 128 content types")

return ak.contents.UnionArray.simplified(
tags, index, contents, parameters=self._parameters
Expand Down Expand Up @@ -1236,8 +1244,8 @@ def _mergemany(self, others: Sequence[Content]) -> Content:

nextcontents.append(array)

if len(nextcontents) > 127:
raise ValueError("FIXME: handle UnionArray with more than 127 contents")
if len(nextcontents) > MAX_UNION_CONTENTS:
raise ValueError("UnionArray cannot have more than 128 content types")

next = ak.contents.UnionArray.simplified(
nexttags,
Expand Down
23 changes: 17 additions & 6 deletions src/awkward/operations/ak_concatenate.py
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,12 @@ def action(inputs, depth, backend, **kwargs):
prototype[start : start + size] = tag
start += size

tags = ak.index.Index8(
if len(regulararrays) < 2**7:
tags_cls = ak.index.Index8
else:
tags_cls = ak.index.Index64

tags = tags_cls(
backend.index_nplike.reshape(
backend.index_nplike.broadcast_to(
prototype, (length, prototype.size)
Expand All @@ -265,17 +270,18 @@ def action(inputs, depth, backend, **kwargs):
return (ak.contents.RegularArray(inner, prototype.size),)

elif all(
isinstance(x, ak.contents.Content)
and x.is_list
or (isinstance(x, ak.contents.NumpyArray) and x.data.ndim > 1)
or not isinstance(x, ak.contents.Content)
(isinstance(x, ak.contents.Content) and x.is_list) # Case 1
or (isinstance(x, ak.contents.NumpyArray) and x.data.ndim > 1) # Case 2
or not isinstance(x, ak.contents.Content) # Case 3: scalar value
for x in inputs
):
nextinputs = []
for x in inputs:
if isinstance(x, ak.contents.Content):
nextinputs.append(x)
else:
# Treat non-content inputs as scalars.
# These become arrays of matching length.
nextinputs.append(
ak.contents.ListOffsetArray(
ak.index.Index64(
Expand All @@ -302,7 +308,7 @@ def action(inputs, depth, backend, **kwargs):
all_flatten = []

for x in nextinputs:
o, f = x._offsets_and_flattened(1, 1)
o, f = x._offsets_and_flattened(axis=1, depth=1)
c = o.data[1:] - o.data[:-1]
backend.index_nplike.add(counts, c, maybe_out=counts)
all_counts.append(c)
Expand All @@ -316,10 +322,15 @@ def action(inputs, depth, backend, **kwargs):

offsets = ak.index.Index64(offsets, nplike=backend.index_nplike)

if len(nextinputs) < 2**7:
tags_cls = ak.index.Index8
else:
tags_cls = ak.index.Index64
tags, index = ak.contents.UnionArray.nested_tags_index(
offsets,
[ak.index.Index64(x) for x in all_counts],
backend=backend,
tags_cls=tags_cls,
)

inner = ak.contents.UnionArray.simplified(
Expand Down
34 changes: 34 additions & 0 deletions tests/test_2881_ak_concatenate_fails_on_too_many_nested_arrays.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward/blob/main/LICENSE

from __future__ import annotations

import awkward as ak

# from awkward.operations import to_list


def test_concatenate_as_reported():
a = ak.Array([[1]])
a_concat_128 = ak.concatenate([a for i in range(128)], axis=1)
assert a_concat_128.to_list() == [[1] * 128]

a_concat_129 = ak.concatenate([a for i in range(129)], axis=1)
assert a_concat_129.to_list() == [[1] * 129]


def test_concatenate_inner_union_simplify_one():
a = ak.Array([[99]])
astr = ak.Array(["a b c d".split()])
aa = [a for i in range(129)] + [astr]

cu = ak.concatenate(aa, axis=1)
assert cu.to_list() == [[99] * 129 + ["a", "b", "c", "d"]]


def test_concatenate_inner_union_simplify():
a = ak.Array([[99]])
amulti = ak.Array([[1, 2, "a", "b"]])
aa = [a for i in range(129)] + [amulti]

cu = ak.concatenate(aa, axis=1)
assert cu.to_list() == [[99] * 129 + [1, 2, "a", "b"]]

0 comments on commit af232f5

Please sign in to comment.